# prepro_util.py
import numpy as np
import unicodedata
import tokenization
from collections import Counter, defaultdict

class SquadExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self,
                 qas_id,
                 question_text,
                 doc_tokens,
                 paragraph_indices=None,
                 orig_answer_text=None,
                 all_answers=None,
                 start_position=None,
                 end_position=None,
                 switch=None):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.paragraph_indices = paragraph_indices
        self.orig_answer_text = orig_answer_text
        self.all_answers = all_answers
        self.start_position = start_position
        self.end_position = end_position
        self.switch = switch

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return "question: " + self.question_text
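
# Illustrative sketch (not part of the original file): a SquadExample is
# typically built from one question plus the whitespace-split context tokens.
# All values below are made-up placeholders.
#
#   example = SquadExample(
#       qas_id="dev-0001",
#       question_text="Who wrote Hamlet?",
#       doc_tokens=["William", "Shakespeare", "wrote", "Hamlet", "."],
#       orig_answer_text="William Shakespeare",
#       all_answers=["William Shakespeare"],
#       start_position=0, end_position=1, switch=0)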


class InputFeatures(object):
    """A single set of model-ready features for one document span of an example."""

    def __init__(self,
                 unique_id,
                 example_index,
                 paragraph_index=None,
                 doc_span_index=None,
                 doc_tokens=None,
                 tokens=None,
                 token_to_orig_map=None,
                 token_is_max_context=None,
                 input_ids=None,
                 input_mask=None,
                 segment_ids=None,
                 start_position=None,
                 end_position=None,
                 switch=None,
                 answer_mask=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.paragraph_index = paragraph_index
        self.doc_span_index = doc_span_index
        self.doc_tokens = doc_tokens
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.start_position = start_position
        self.end_position = end_position
        self.switch = switch
        self.answer_mask = answer_mask


def _run_strip_accents(text):
    """Strips accents from a piece of text."""
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
        cat = unicodedata.category(char)
        if cat == "Mn":
            continue
        output.append(char)
    return "".join(output)


def find_span_from_text(context, tokens, answer):
    """Finds all spans in `context` whose text exactly matches `answer`.

    `tokens` is a tokenization of `context`; WordPiece '##' continuation
    markers, if present, are stripped before matching. Returns a list of
    dicts with the character offset ('answer_start') and the token indices
    ('word_start', 'word_end') of each match.
    """
    assert answer in context

    offset = 0
    spans = []
    scanning = None
    process = []  # trace of the matching state (not used below)

    for i, token in enumerate(tokens):
        token = token.replace(' ##', '').replace('##', '')
        # Advance `offset` to where this token occurs in the raw context.
        while context[offset:offset+len(token)] != token:
            offset += 1
            if offset >= len(context):
                break
        if scanning is not None:
            # In the middle of a multi-token candidate: check whether the text
            # accumulated so far is still a prefix of (or equal to) the answer.
            end = offset + len(token)
            if answer.startswith(context[scanning[-1][-1]:end]):
                if context[scanning[-1][-1]:end] == answer:
                    span = (scanning[0][0], i, scanning[0][1])
                    spans.append(span)
                elif len(context[scanning[-1][-1]:end]) >= len(answer):
                    scanning = None
            else:
                scanning = None
        if scanning is None and answer.startswith(token):
            # This token could be (or start) a new candidate answer.
            if token == answer:
                spans.append((i, i, offset))
            if token != answer:
                scanning = [(i, offset)]
        offset += len(token)
        if offset >= len(context):
            break
        process.append((token, offset, scanning, spans))

    answers = []
    for word_start, word_end, span in spans:
        assert context[span:span+len(answer)] == answer or \
            ''.join(tokens[word_start:word_end+1]).replace('##', '') != answer.replace(' ', '')
        answers.append({'text': answer, 'answer_start': span,
                        'word_start': word_start, 'word_end': word_end})
    return answers
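
# Illustrative usage (not part of the original file):
#   find_span_from_text("the quick brown fox",
#                       ["the", "quick", "brown", "fox"], "brown fox")
# returns
#   [{'text': 'brown fox', 'answer_start': 10, 'word_start': 2, 'word_end': 3}]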


def detect_span(_answers, context, doc_tokens, char_to_word_offset):
    """Finds token-level start/end positions (and a `switch` flag) for each
    annotated answer that occurs in `context`."""
    orig_answer_texts = []
    start_positions = []
    end_positions = []
    switches = []

    answers = []
    for answer in _answers:
        answers += find_span_from_text(context, doc_tokens, answer['text'])

    for answer in answers:
        orig_answer_text = answer["text"]
        answer_offset = answer["answer_start"]
        answer_length = len(orig_answer_text)
        switch = 0
        if 'word_start' in answer and 'word_end' in answer:
            start_position = answer['word_start']
            end_position = answer['word_end']
        else:
            start_position = char_to_word_offset[answer_offset]
            end_position = char_to_word_offset[answer_offset + answer_length - 1]

        # Only add answers where the text can be exactly recovered from the
        # document. If this CAN'T happen it's likely due to weird Unicode
        # stuff so we will just skip the example.
        #
        # Note that this means for training mode, every example is NOT
        # guaranteed to be preserved.
        actual_text = " ".join(doc_tokens[start_position:(end_position + 1)]).replace(' ##', '').replace('##', '')
        cleaned_answer_text = " ".join(
            tokenization.whitespace_tokenize(orig_answer_text))
        if actual_text.replace(' ', '').find(cleaned_answer_text.replace(' ', '')) == -1:
            print("Could not find answer: '%s' vs. '%s'" % (actual_text, cleaned_answer_text))

        orig_answer_texts.append(orig_answer_text)
        start_positions.append(start_position)
        end_positions.append(end_position)
        switches.append(switch)

    return orig_answer_texts, switches, start_positions, end_positions
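

if __name__ == "__main__":
    # Minimal sanity check added for illustration; it is not part of the
    # original preprocessing pipeline and the toy inputs below are made up.
    # It assumes the repo's `tokenization.whitespace_tokenize` splits text on
    # whitespace, as in the standard BERT tokenization module.
    context = "the quick brown fox"
    doc_tokens = ["the", "quick", "brown", "fox"]
    texts, switches, starts, ends = detect_span(
        [{'text': 'brown fox'}], context, doc_tokens, char_to_word_offset=[])
    # Expected output: ['brown fox'] [0] [2] [3]
    print(texts, switches, starts, ends)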