-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutils_e2e.py
94 lines (75 loc) · 2.79 KB
/
utils_e2e.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
Utilities.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import functools
import os

import texar as tx
from tensorflow.contrib.seq2seq import tile_batch

from data2text.data_utils import get_train_ents, extract_entities, extract_numbers
# Load the entity vocabulary that ``replace_data_in_sent`` matches against.
# The RotoWire variant below (via ``get_train_ents``) is kept for reference only.
#all_ents, players, teams, cities = get_train_ents(path=os.path.join("data2text", "rotowire"), connect_multiwords=True)
def _load_entity_vocab(path):
    """Return the set of lines in *path*, each with its newline stripped.

    Reading happens inside this helper so the file handle and loop
    temporaries do not leak into the module namespace (and thus are not
    re-exported by ``from utils_e2e import *``).
    NOTE(review): a blank line in the vocab file adds '' to the set.
    """
    with open(path, 'r') as vocab_file:
        return set(line.strip('\n') for line in vocab_file)


all_ents = _load_entity_vocab('e2e_data_v14/e2e.entry.vocab.txt')
# Name factories for TF scopes: each is a bound ``str.format``, so e.g.
# get_scope_name_of_train_op('g') == 'train_g'.
get_scope_name_of_summary_op = 'summary_{}'.format
get_scope_name_of_train_op = 'train_{}'.format

# Target-sentence field, then the parallel structured-data fields.
sent_fields = ['sent']
sd_fields = 'entry attribute value'.split()
all_fields = sent_fields + sd_fields

# File-name suffix, indexed by ``ref_flag``: 0 -> '', 1 -> '_ref'.
ref_strs = ['', '_ref']
class DataItem(collections.namedtuple('DataItem', sd_fields)):
    """One structured-data record whose fields are named by ``sd_fields``.

    ``str()`` renders the field values joined by ``'|'``.
    """

    # An empty ``__slots__`` keeps the subclass from growing a per-instance
    # ``__dict__``, so records stay as light as a plain namedtuple.
    __slots__ = ()

    def __str__(self):
        return '|'.join(map(str, self))
def pack_sd(paired_texts):
    """Transpose per-field sequences into a list of ``DataItem`` records.

    ``paired_texts`` holds one sequence per ``DataItem`` field, in field
    order. Element ``i`` of every sequence is combined into the ``i``-th
    record. As with ``zip``, output stops at the shortest sequence.
    """
    return [DataItem(*field_values) for field_values in zip(*paired_texts)]
def batchize(func):
    """Lift a per-example function so it runs over a whole batch.

    The returned callable takes one sequence per positional argument of
    ``func``. It applies ``func`` to the ``i``-th elements of those sequences
    and collects the results in a list. Like ``zip``, it stops at the
    shortest sequence.

    ``functools.wraps`` copies ``func``'s name and docstring onto the wrapper
    so tracebacks and introspection still identify the original function.
    """
    @functools.wraps(func)
    def batchized_func(*inputs):
        return [func(*paired_inputs) for paired_inputs in zip(*inputs)]
    return batchized_func
def strip_special_tokens_of_list(text):
    """Remove texar special tokens from an already-tokenized sequence.

    Thin wrapper over ``tx.utils.strip_special_tokens`` that always passes
    ``is_token_list=True``.
    """
    stripped = tx.utils.strip_special_tokens(text, is_token_list=True)
    return stripped


# Batch form: applies ``strip_special_tokens_of_list`` to each element of a
# sequence and returns the list of results.
batch_strip_special_tokens_of_list = batchize(strip_special_tokens_of_list)
def replace_data_in_sent(sent, token="<UNK>"):
    """Mask the data tokens of ``sent`` with ``token``.

    Entities are found with ``extract_entities`` against the module-level
    ``all_ents`` vocabulary. For each entity, the token at its ``start``
    index is overwritten with ``token``.

    ``sent`` is modified IN PLACE and also returned.
    NOTE(review): only the token at ``start`` is masked; if an entity can
    span several tokens, its remaining tokens are left as-is — confirm.
    """
    entities = extract_entities(sent, all_ents)
    # Walk entities from the rightmost start position to the leftmost.
    entities.sort(key=lambda ent: ent.start, reverse=True)
    for ent in entities:
        sent[ent.start] = token
    return sent
def corpus_bleu(list_of_references, hypotheses, **kwargs):
    """Corpus-level BLEU computed after masking data tokens.

    Each reference and hypothesis goes through ``replace_data_in_sent``,
    which replaces matched entity tokens with its default ``"<UNK>"``. The
    masked corpus is then scored with ``tx.evals.corpus_bleu_moses``, using
    ``lowercase=True`` and ``return_all=False``.

    Args:
        list_of_references: for each example, a sequence of references.
        hypotheses: one hypothesis per example.
        **kwargs: forwarded unchanged to ``tx.evals.corpus_bleu_moses``.

    Returns:
        The result of ``tx.evals.corpus_bleu_moses``.
    """
    # ``replace_data_in_sent`` writes into its argument, so mask shallow
    # copies. Otherwise scoring would silently overwrite the caller's
    # hypotheses and references with "<UNK>". ``copy.copy`` keeps the
    # original sequence type (a list stays a list, a str stays a str).
    list_of_references = [
        [replace_data_in_sent(copy.copy(ref)) for ref in refs]
        for refs in list_of_references]
    hypotheses = [replace_data_in_sent(copy.copy(hyp)) for hyp in hypotheses]
    return tx.evals.corpus_bleu_moses(
        list_of_references, hypotheses,
        lowercase=True, return_all=False,
        **kwargs)
def read_sents_from_file(file_name):
    """Return the lines of ``file_name``, each split on whitespace into tokens.

    Produces one token list per line. A blank line becomes ``[]``.
    """
    with open(file_name, 'r') as source:
        return [line.split() for line in source]
def read_x(data_prefix, ref_flag, stage):
    """Read the structured-data side of a split as lists of ``DataItem``.

    For every field in ``sd_fields``, the file
    ``'{data_prefix}{field}{ref_str}.{stage}.txt'`` is read, where
    ``ref_str = ref_strs[ref_flag]``. Line ``i`` across all field files
    describes example ``i``. Within that line, token ``j`` of each file
    together forms the ``j``-th ``DataItem``.

    Returns:
        One list of ``DataItem`` per example. As with ``zip``, mismatched
        line or token counts are truncated to the shortest.
    """
    ref_str = ref_strs[ref_flag]
    # Each field file, read in ``sd_fields`` order.
    sents_by_field = [
        read_sents_from_file(
            '{}{}{}.{}.txt'.format(data_prefix, field, ref_str, stage))
        for field in sd_fields]
    return [
        [DataItem(*token_tuple) for token_tuple in zip(*example_fields)]
        for example_fields in zip(*sents_by_field)]
def read_y(data_prefix, ref_flag, stage):
    """Read the target-sentence side of a split.

    Loads ``'{data_prefix}{sent_fields[0]}{ref_str}.{stage}.txt'``, where
    ``ref_str = ref_strs[ref_flag]``, and returns one token list per line.
    """
    file_name = '{}{}{}.{}.txt'.format(
        data_prefix, sent_fields[0], ref_strs[ref_flag], stage)
    return read_sents_from_file(file_name)
def divide_or_const(a, b, c=0.):
    """Return ``a / b``, or the fallback ``c`` if the division raises
    ``ZeroDivisionError``.

    Only ``ZeroDivisionError`` is caught; any other error propagates.
    """
    try:
        quotient = a / b
    except ZeroDivisionError:
        quotient = c
    return quotient