# relation_f1.py

from sciriff.eval.metrics import util


class RelationF1:
    "Exact-match relation F1."

    def __init__(self, tup_len, overlap_thresh):
        self.tup_len = tup_len  # Relation tuple length.
        self.overlap_thresh = overlap_thresh

    @staticmethod
    def _normalize(xs):
        "Normalize a single predicted list by converting to lowercase strings."
        try:
            return [util.normalize_list_entry(x) for x in xs]
        except TypeError:
            return None

    def _parse_entry(self, entry):
        "Parse a single entry in the model's list of predictions."
        if not isinstance(entry, list):
            self.counts_parse["entry_not_list"] += 1
            return None
        elif len(entry) == self.tup_len:
            res = self._normalize(entry)
            if res is not None:
                self.counts_parse["good_list"] += 1
                return tuple(res)
            else:
                self.counts_parse["bad_list"] += 1
                return None
        else:
            self.counts_parse["wrong_length"] += 1
            return None

    def _parse_prediction(self, pred):
        "Parse a prediction, counting errors if they occur."
        if pred == []:
            return []

        res = []
        try:
            is_list_of_lists = isinstance(pred[0], list)
        except TypeError:
            # This triggers if the prediction is an int or some other
            # non-subscriptable datatype.
            return []

        if is_list_of_lists:
            # If the first entry in the prediction is a list, attempt to parse
            # each entry. Throw out those that aren't lists, or that have the
            # wrong length.
            for entry in pred:
                parsed = self._parse_entry(entry)
                if parsed is not None:
                    res.append(parsed)
        else:
            # Otherwise, maybe the model just returned a non-nested list; use
            # it if possible.
            parsed = self._parse_entry(pred)
            if parsed is not None:
                res.append(parsed)

        return res
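
    # A hypothetical example, assuming tup_len == 3: the prediction
    #   [["aspirin", "treats", "headache"], ["too", "short"]]
    # parses to [("aspirin", "treats", "headache")]; the second entry
    # increments counts_parse["wrong_length"] instead.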

    def _evaluate_f1_exact(self, pred, ref):
        "Count exact matches."
        pred_set = set(pred)
        ref_set = set(ref)
        self.counts_score["exact_match"] += len(pred_set & ref_set)

    def _is_substring_match(self, pred_entry, ref_entry):
        "Return True if every pred element is a sub- or superstring of its ref element."
        if len(pred_entry) != self.tup_len or len(ref_entry) != self.tup_len:
            raise ValueError("Unexpected entry length.")

        for pred_item, ref_item in zip(pred_entry, ref_entry):
            if pred_item == "":
                return False
            if pred_item not in ref_item and ref_item not in pred_item:
                return False

        return True
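
    # A hypothetical pair, with tup_len == 2: ("aspirin", "head")
    # substring-matches ("aspirin", "headache"), while ("", "headache") never
    # matches, since empty prediction elements are rejected outright.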

    def _is_overlap(self, pred_entry, ref_entry):
        "Return True if every element has token F1 of at least `self.overlap_thresh`."
        if len(pred_entry) != self.tup_len or len(ref_entry) != self.tup_len:
            raise ValueError("Unexpected entry length.")

        for pred_item, ref_item in zip(pred_entry, ref_entry):
            if pred_item == "":
                return False
            if util.compute_token_f1(pred_item, ref_item) < self.overlap_thresh:
                return False

        return True
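
    # A hypothetical pair, with overlap_thresh == 0.5: ("acute migraine",
    # "aspirin") overlap-matches ("migraine", "aspirin"), since the token F1
    # of "acute migraine" against "migraine" is 2/3, assuming
    # util.compute_token_f1 computes bag-of-tokens overlap F1.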

    def _evaluate_f1_fuzzy(self, sim_fn, sim_name, pred, ref):
        "Compute fuzzy F1 score given a similarity function `sim_fn`."
        already_used_refs = set()
        for pred_entry in pred:
            for ref_entry in ref:
                if ref_entry in already_used_refs:
                    continue
                if sim_fn(pred_entry, ref_entry):
                    self.counts_score[sim_name] += 1
                    already_used_refs.add(ref_entry)
                    # Continue on to the next prediction so we don't
                    # double-count.
                    break
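
    # Note that matching is greedy in prediction order: each reference entry
    # can be consumed by at most one prediction, so the fuzzy match count
    # never exceeds min(len(pred), len(ref)).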

    def _evaluate_one(self, instance):
        pred = self._parse_prediction(instance["pred"])
        ref = [tuple(self._normalize(entry)) for entry in instance["ref"]]

        self.counts_score["preds"] += len(pred)
        self.counts_score["refs"] += len(ref)

        self._evaluate_f1_exact(pred, ref)
        self._evaluate_f1_fuzzy(self._is_substring_match, "substring_match", pred, ref)
        self._evaluate_f1_fuzzy(self._is_overlap, "overlap_match", pred, ref)

    def evaluate(self, instances):
        self.counts_parse = util.count_dict(
            ["entry_not_list", "good_list", "bad_list", "wrong_length"]
        )
        self.counts_score = util.count_dict(
            ["preds", "refs", "exact_match", "substring_match", "overlap_match"]
        )

        for instance in instances:
            self._evaluate_one(instance)

        res = {}
        res["exact"] = util.compute_f1(
            self.counts_score["exact_match"],
            self.counts_score["preds"],
            self.counts_score["refs"],
        )
        res["substring"] = util.compute_f1(
            self.counts_score["substring_match"],
            self.counts_score["preds"],
            self.counts_score["refs"],
        )
        res["overlap"] = util.compute_f1(
            self.counts_score["overlap_match"],
            self.counts_score["preds"],
            self.counts_score["refs"],
        )

        self.counts_parse["frac_success"] = util.safe_div(
            self.counts_parse["good_list"], util.sum_dict(self.counts_parse)
        )
        res["parse_counts"] = self.counts_parse

        return res
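

if __name__ == "__main__":
    # Minimal smoke-test sketch of the expected input format. The instances
    # below are hypothetical, and we assume the helpers in
    # sciriff.eval.metrics.util behave as their names suggest.
    metric = RelationF1(tup_len=2, overlap_thresh=0.5)
    instances = [
        {
            "pred": [["aspirin", "headache"], ["ibuprofen", "fever"]],
            "ref": [["aspirin", "headache"], ["ibuprofen", "inflammation"]],
        }
    ]
    results = metric.evaluate(instances)
    print(results)  # Keys: "exact", "substring", "overlap", "parse_counts".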