-
Notifications
You must be signed in to change notification settings - Fork 18
/
coref_util.py
217 lines (174 loc) · 6.73 KB
/
coref_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
from typing import Dict, List, Optional, Tuple

from spacy.language import Language
from spacy.tokens import Doc
from thinc.api import NumpyOps
from thinc.types import Ints1d, Ints2d, Floats2d
# Shorthand for a list of clusters, where each cluster is a list of
# (start, end) integer pairs.
MentionClusters = List[List[Tuple[int, int]]]

# Default span-key prefix for coref cluster span groups.
DEFAULT_CLUSTER_PREFIX: str = "coref_clusters"
# Default span-key prefix for head-based coref cluster span groups; this is
# also the default target of the experimental_span_cleaner factory below.
DEFAULT_CLUSTER_HEAD_PREFIX: str = "coref_head_clusters"
@Language.factory(
    "experimental_span_cleaner",
    assigns=["doc.spans"],
    default_config={"prefix": DEFAULT_CLUSTER_HEAD_PREFIX},
)
def make_span_cleaner(nlp: Language, name: str, *, prefix: str) -> "SpanCleaner":
    """Build the ``experimental_span_cleaner`` component.

    Args:
        nlp: The pipeline, supplied by the factory machinery (unused here).
        name: The component name, supplied by the factory machinery (unused).
        prefix: When the component runs, every ``doc.spans`` key that starts
            with this string is deleted.

    Returns:
        A ``SpanCleaner`` configured with ``prefix``.
    """
    cleaner = SpanCleaner(prefix)
    return cleaner
class SpanCleaner:
    """Callable that removes span groups whose key starts with a prefix.

    Calling it with a Doc deletes those entries from ``doc.spans`` in place
    and returns the same Doc.
    """

    def __init__(self, prefix: str):
        # any doc.spans key beginning with this string gets removed
        self.prefix = prefix

    def __call__(self, doc: Doc) -> Doc:
        # Gather matching keys up front so the mapping is not mutated while
        # it is being iterated.
        doomed = [key for key in doc.spans.keys() if key.startswith(self.prefix)]
        for key in doomed:
            del doc.spans[key]
        return doc
def matches_coref_prefix(prefix: str, key: str) -> bool:
    """Check if a span key matches a coref prefix.

    Given prefix "xxx", "xxx_1" is a matching span, but "xxx_yyy",
    "xxx_yyy_1" and "xxx1" are not matching spans. The prefix must be
    followed by exactly one underscore and a non-empty run of ASCII digits.

    Args:
        prefix: The span-key prefix, e.g. "coref_clusters".
        key: The span key to test, e.g. "coref_clusters_3".

    Returns:
        True if ``key`` is ``prefix + "_" + <digits>``, otherwise False.
    """
    # Check the separator explicitly. Slicing one character past the prefix
    # without checking it would let "xxx11" or "xxxa1" match prefix "xxx".
    head = prefix + "_"
    if not key.startswith(head):
        return False
    suffix = key[len(head):]
    # int() would also accept "-1", " 1", "+1" and "1_0"; only a plain
    # non-negative index is a valid suffix.
    return bool(suffix) and all(ch in "0123456789" for ch in suffix)
def get_sentence_ids(doc: Doc) -> List[int]:
    """Map every token of ``doc`` to the index of the sentence containing it.

    A counter starts at -1 and is incremented on each token whose
    ``is_sent_start`` is truthy. Assuming the first token starts a sentence,
    tokens of the first sentence get 0, the next sentence 1, and so on.
    Coref uses this to make sure mentions don't cross sentence boundaries.
    """
    sentence_of_token: List[int] = []
    current_sentence = -1
    for token in doc:
        current_sentence += 1 if token.is_sent_start else 0
        sentence_of_token.append(current_sentence)
    return sentence_of_token
# from model.py, refactored to be non-member
def get_predicted_antecedents(
    xp, antecedent_idx: "Ints2d", antecedent_scores: "Floats2d"
) -> "Ints1d":
    """Get the ID of the antecedent for each span; -1 if no antecedent.

    Args:
        xp: Array module used for the computation (e.g. numpy).
        antecedent_idx: Row ``i`` lists the candidate antecedent ids of span
            ``i``.
        antecedent_scores: Scores per span with one more column than
            ``antecedent_idx``. Column 0 is the "no antecedent" choice;
            column ``k + 1`` scores candidate ``antecedent_idx[i, k]``.

    Returns:
        A 1-D array with the dtype of ``antecedent_idx``, holding the chosen
        antecedent id per span, or -1 where the "no antecedent" column won.
        Zero spans produce an empty array.
    """
    # Shift by one so that winning column 0 becomes -1 and candidate columns
    # become indices into antecedent_idx.
    predicted_antecedents = xp.argmax(antecedent_scores, axis=1) - 1
    out = xp.full(antecedent_idx.shape[0], -1, dtype=antecedent_idx.dtype)
    valid_indices = predicted_antecedents != -1
    # .any() is False both when no span has an antecedent and when there are
    # no spans at all (where .max() would raise on the empty array).
    if not valid_indices.any():
        return out
    # Rows with -1 gather from the last column via negative indexing, but
    # the boolean mask discards those rows.
    valid_preds = antecedent_idx[
        xp.arange(antecedent_idx.shape[0]), predicted_antecedents
    ][valid_indices]
    xp.place(out, valid_indices, valid_preds)
    return out
# from model.py, refactored to be non-member
def get_predicted_clusters(
    xp,
    span_starts: Ints1d,
    span_ends: Ints1d,
    antecedent_idx: Ints2d,
    antecedent_scores: Floats2d,
):
    """Turn antecedent predictions into clusters of (start, end) spans.

    Each span is linked to the antecedent chosen by
    ``get_predicted_antecedents``, and linked spans are chained into clusters.

    Returns:
        A list of clusters. Each cluster is a tuple of ``(start, end)`` int
        pairs read from ``span_starts``/``span_ends``. It begins with the
        antecedent that opened it, followed by the spans linked into it in
        increasing span order. A span that has no antecedent and is nobody's
        antecedent appears in no cluster.
    """
    # Move the predicted antecedent ids to the CPU as a plain Python list.
    antecedents = NumpyOps().asarray(
        get_predicted_antecedents(xp, antecedent_idx, antecedent_scores)
    ).tolist()

    def span_at(index):
        return (int(span_starts[index]), int(span_ends[index]))

    cluster_of_span = {}
    clusters = []
    for span_idx, ante_idx in enumerate(antecedents):
        if ante_idx < 0:
            # negative id: this span has no antecedent
            continue
        assert (
            span_idx > ante_idx
        ), f"span idx: {span_idx}; antecedent idx: {ante_idx}"
        ante_span = span_at(ante_idx)
        if ante_span not in cluster_of_span:
            # The antecedent is not clustered yet, so it opens a new cluster.
            cluster_of_span[ante_span] = len(clusters)
            clusters.append([ante_span])
        target = cluster_of_span[ante_span]
        mention_span = span_at(span_idx)
        clusters[target].append(mention_span)
        cluster_of_span[mention_span] = target
    return [tuple(members) for members in clusters]
def create_head_span_idxs(ops, doclen: int):
    """Return a (doclen, 2) array whose row ``i`` is the span ``[i, i + 1]``.

    Each row is a single-token span. The ranges are built with ``ops.xp``
    and packed into an int array with ``ops.asarray2i``.
    """
    starts = ops.xp.arange(0, doclen)
    ends = starts + 1
    return ops.asarray2i([starts, ends]).T
def get_clusters_from_doc(
    doc: "Doc", *, use_heads: bool = False, prefix: Optional[str] = None
) -> List[List[Tuple[int, int]]]:
    """Convert the span clusters in a Doc to simple integer tuple lists.

    Each mention is reported as a (start_char, end_char) character-offset
    pair rather than as token indices, so the result is tokenization
    independent.

    Args:
        doc: The Doc whose ``doc.spans`` groups are read. Each group is
            treated as one cluster.
        use_heads: If True, each mention is reduced to the character span of
            its head token (``span.root``).
        prefix: If given, only span groups whose key matches it (see
            ``matches_coref_prefix``) are used.

    Returns:
        One list per selected span group, in sorted key order. Each list
        holds that group's distinct character spans, sorted by offset.
    """
    out = []
    for key in sorted(doc.spans.keys()):
        if prefix is not None and not matches_coref_prefix(prefix, key):
            continue
        mentions = set()
        for span in doc.spans[key]:
            if use_heads:
                head = span.root
                mentions.add((head.idx, head.idx + len(head)))
            else:
                mentions.add((span[0].idx, span[-1].idx + len(span[-1])))
        # Dedupe and sort, so the order is deterministic and follows the
        # document instead of depending on set iteration order.
        out.append(sorted(mentions))
    return out
def create_gold_scores(
    ments: Ints2d, clusters: List[List[Tuple[int, int]]]
) -> Floats2d:
    """Build the gold antecedent matrix for a set of candidate mentions.

    ``out[i, j]`` is 1.0 when candidate ``j`` comes before candidate ``i``
    (``j < i``) and both belong to the same gold cluster, and 0.0 otherwise.
    The "no antecedent" placeholder column is not included.

    Args:
        ments: Candidate mentions. Columns 0 and 1 of row ``i`` hold the
            (start, end) pair of mention ``i``. The array is converted to a
            CPU array with ``NumpyOps``.
        clusters: Gold clusters, each a list of (start, end) pairs in the
            same coordinates as ``ments``.

    Returns:
        A (len(ments), len(ments)) float matrix of 1/0 values, allocated by
        ``NumpyOps.alloc2f``.
    """
    # Map each gold mention to a cluster id. The id values are arbitrary;
    # only equality between ids matters.
    ment2cid: Dict[Tuple[int, int], int] = {}
    for cid, cluster in enumerate(clusters):
        for ment in cluster:
            ment2cid[ment] = cid

    ops = NumpyOps()
    xp = ops.xp
    cpu_ments = ops.asarray(ments)
    ll = len(cpu_ments)
    out = ops.alloc2f(ll, ll)
    # Cluster id of each candidate, or -1 if it is in no gold cluster. This
    # costs one dict lookup per mention instead of one per (mention,
    # antecedent) pair.
    cids = xp.asarray(
        [ment2cid.get((int(ment[0]), int(ment[1])), -1) for ment in cpu_ments],
        dtype="int64",
    )
    # (i, j) is gold iff mention i is clustered and j is in the same cluster.
    # Antecedents must come first, so tril(k=-1) keeps only pairs with j < i.
    same_cluster = (cids[:, None] == cids[None, :]) & (cids[:, None] >= 0)
    out[xp.tril(same_cluster, k=-1)] = 1.0
    return out