-
Notifications
You must be signed in to change notification settings - Fork 2
/
non_rg_metrics.py
136 lines (117 loc) · 4.47 KB
/
non_rg_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
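Compute non-RG metrics between gold and predicted record tuples:
micro-averaged precision/recall (calc_precrec) and the average normalized
Damerau-Levenshtein similarity of the tuple sequences (calc_dld).

Input files hold one '|'-separated tuple per line; blank lines separate
documents, and both files must contain the same number of documents.

Usage:
    python non_rg_metrics.py gold_tuple_fi pred_tuple_fi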
"""
from pyxdameraulevenshtein import normalized_damerau_levenshtein_distance
import sys
# Full names ("<city> <nickname>") of the 30 teams treated as team entities.
full_names = ['Atlanta Hawks', 'Boston Celtics', 'Brooklyn Nets',
              'Charlotte Hornets', 'Chicago Bulls', 'Cleveland Cavaliers',
              'Detroit Pistons', 'Indiana Pacers', 'Miami Heat',
              'Milwaukee Bucks', 'New York Knicks', 'Orlando Magic',
              'Philadelphia 76ers', 'Toronto Raptors', 'Washington Wizards',
              'Dallas Mavericks', 'Denver Nuggets', 'Golden State Warriors',
              'Houston Rockets', 'Los Angeles Clippers', 'Los Angeles Lakers',
              'Memphis Grizzlies', 'Minnesota Timberwolves',
              'New Orleans Pelicans', 'Oklahoma City Thunder', 'Phoenix Suns',
              'Portland Trail Blazers', 'Sacramento Kings', 'San Antonio Spurs',
              'Utah Jazz']

# cities / teams: every city part and every nickname part of full_names.
# ec ("equivalence classes"): full name -> [city, nickname].
cities, teams = set(), set()
ec = {}
for team in full_names:
    words = team.split()
    if len(words) == 2 or words[0] == "Portland":
        # One-word city; the nickname is the rest ("Trail Blazers" is the
        # only multi-word nickname).
        city, nickname = words[0], " ".join(words[1:])
    else:
        # Every other three-word name has a two-word city.
        city, nickname = " ".join(words[:2]), words[2]
    ec[team] = [city, nickname]
    cities.add(city)
    teams.add(nickname)
def same_ent(e1, e2):
    """Return True if entity strings ``e1`` and ``e2`` count as the same entity.

    If ``e1`` is a known city or team nickname, ``e2`` matches when it equals
    ``e1`` or when ``e1`` and ``e2`` are both substrings of one of the
    ``full_names``.  For any other ``e1`` (presumably player names -- confirm
    against the tuple files), the two match when either is a substring of the
    other.

    NOTE(review): the test is asymmetric (only ``e1`` picks the branch) and
    uses raw substring containment, so e.g. "Los Angeles" matches both
    "Lakers" and "Clippers".  Changing this would change the reported scores.
    """
    if e1 not in cities and e1 not in teams:
        return e1 in e2 or e2 in e1
    if e1 == e2:
        return True
    return any(e1 in name and e2 in name for name in full_names)
def trip_match(t1, t2):
    """Return True if two record tuples match.

    Fields 1 and 2 must be exactly equal.  Field 0 (the entity) is then
    compared with ``same_ent``.  The checks short-circuit in that order, so
    ``same_ent`` only runs when fields 1 and 2 already agree.  The tuples are
    presumably (entity, value, relation type) -- confirm against the files.
    """
    if t1[1] != t2[1] or t1[2] != t2[2]:
        return False
    return same_ent(t1[0], t2[0])
def dedup_triples(triplist):
    """Return ``triplist`` without triples that ``trip_match`` an earlier one.

    The first occurrence is kept and order is preserved.  Each triple is
    compared against *all* earlier triples, including ones that are themselves
    dropped.  This matters because ``trip_match`` is fuzzy and not necessarily
    transitive.  The cost is quadratic in ``len(triplist)``.
    """
    kept = []
    for pos, triple in enumerate(triplist):
        if not any(trip_match(triple, triplist[j]) for j in range(pos)):
            kept.append(triple)
    return kept
def get_triples(fi, encoding="utf-8"):
    """Read a tuple file and return one deduplicated list of tuples per document.

    Each non-blank line is one tuple whose fields are separated by '|'
    (surrounding whitespace is stripped first).  A whitespace-only line closes
    the current document.  Consecutive blank lines therefore yield empty
    documents, which keeps documents aligned between files.  A final document
    with no trailing blank line is still included.  Each document's tuples are
    passed through ``dedup_triples``.

    Args:
        fi: path of the tuple file.
        encoding: text encoding of the file.  Defaults to UTF-8 rather than
            the platform-dependent locale encoding that a bare ``open()``
            would use.

    Returns:
        A list with one list of str tuples per document.
    """
    all_triples = []
    curr = []
    with open(fi, encoding=encoding) as f:
        for line in f:
            if line.isspace():
                # Document boundary.
                all_triples.append(dedup_triples(curr))
                curr = []
            else:
                curr.append(tuple(line.strip().split('|')))
    if curr:
        all_triples.append(dedup_triples(curr))
    return all_triples
def calc_precrec(goldfi, predfi):
    """Compute micro-averaged precision and recall of predicted vs. gold tuples.

    Both files are read with ``get_triples`` and must contain the same number
    of documents.  Within each document, a predicted tuple is a true positive
    if it ``trip_match``-es at least one gold tuple of that document.  Counts
    are summed over all documents before dividing.  The totals and scores are
    printed to stdout.

    Args:
        goldfi: path of the gold tuple file.
        predfi: path of the predicted tuple file.

    Returns:
        ``(precision, recall)`` as floats.  Precision is 0.0 when nothing was
        predicted, and recall is 0.0 when there are no gold tuples.  In both
        cases the original code raised ZeroDivisionError.

    Raises:
        ValueError: if the files contain different numbers of documents.
            An ``assert`` was used before, which ``python -O`` strips.

    NOTE(review): recall reuses the predicted-side true-positive count.
    Several predictions that fuzzily match the same gold tuple each count, so
    recall can in principle exceed 1.0.
    """
    gold_triples = get_triples(goldfi)
    pred_triples = get_triples(predfi)
    if len(gold_triples) != len(pred_triples):
        raise ValueError(
            "gold and predicted files have different numbers of documents: "
            "%d vs %d" % (len(gold_triples), len(pred_triples)))
    total_tp = total_predicted = total_gold = 0
    for pred_doc, gold_doc in zip(pred_triples, gold_triples):
        # A prediction counts once if it matches any gold tuple in its document.
        total_tp += sum(1 for pred in pred_doc
                        if any(trip_match(pred, gold) for gold in gold_doc))
        total_predicted += len(pred_doc)
        total_gold += len(gold_doc)
    avg_prec = float(total_tp) / total_predicted if total_predicted else 0.0
    avg_rec = float(total_tp) / total_gold if total_gold else 0.0
    print("totals:", total_tp, total_predicted, total_gold)
    print("prec:", avg_prec, "rec:", avg_rec)
    return avg_prec, avg_rec
def norm_dld(l1, l2):
    """Return 1 - normalized Damerau-Levenshtein distance between two triple lists.

    Each triple in ``l1`` gets its own symbol ``chr(0)``, ``chr(1)``, ... by
    position.  Lists from ``get_triples`` are already deduplicated, so
    positions stand for distinct triples.  Each triple in ``l2`` is encoded as
    the symbol of the first ``l1`` triple it ``trip_match``-es.  If none
    matches, it gets a fresh symbol that ``l1`` does not use.  The two symbol
    strings are compared with ``normalized_damerau_levenshtein_distance``.
    That function gives 0 to identical strings, so 1.0 here means the two
    sequences agree exactly under ``trip_match``.

    ``calc_dld`` passes the predicted triples as ``l1`` and the gold triples
    as ``l2``.

    NOTE(review): the ``<= 128`` assertion only bounds the fresh symbols given
    to unmatched ``l2`` triples; ``l1``'s own symbols are never checked, and
    ``python -O`` removes the check entirely.
    """
    first_symbols = [chr(pos) for pos in range(len(l1))]
    second_symbols = []
    next_code = len(first_symbols)
    for other in l2:
        # Reuse the symbol of the first matching l1 triple, if there is one.
        shared = next((sym for sym, cand in zip(first_symbols, l1)
                       if trip_match(other, cand)), None)
        if shared is None:
            # No counterpart in l1: mint a symbol that l1 does not use.
            second_symbols.append(chr(next_code))
            next_code += 1
            assert next_code <= 128
        else:
            second_symbols.append(shared)
    return 1.0 - normalized_damerau_levenshtein_distance(
        ''.join(first_symbols), ''.join(second_symbols))
def calc_dld(goldfi, predfi):
    """Average ``norm_dld`` score between predicted and gold tuple sequences.

    Both files are read with ``get_triples`` and must contain the same number
    of documents.  Each document's predicted list is scored against its gold
    list with ``norm_dld``, and the scores are averaged over documents.  The
    average is printed to stdout.

    Args:
        goldfi: path of the gold tuple file.
        predfi: path of the predicted tuple file.

    Returns:
        The mean per-document score as a float.  It is 0.0 when the files
        contain no documents; the original code raised ZeroDivisionError.

    Raises:
        ValueError: if the files contain different numbers of documents.
            An ``assert`` was used before, which ``python -O`` strips.
    """
    gold_triples = get_triples(goldfi)
    pred_triples = get_triples(predfi)
    if len(gold_triples) != len(pred_triples):
        raise ValueError(
            "gold and predicted files have different numbers of documents: "
            "%d vs %d" % (len(gold_triples), len(pred_triples)))
    total_score = sum(norm_dld(pred_doc, gold_doc)
                      for pred_doc, gold_doc in zip(pred_triples, gold_triples))
    avg_score = float(total_score) / len(pred_triples) if pred_triples else 0.0
    print("avg score:", avg_score)
    return avg_score
if __name__ == "__main__":
    # Run the metrics only when executed as a script, so importing this module
    # (e.g. to reuse calc_precrec / calc_dld) does not read sys.argv or do I/O.
    if len(sys.argv) < 3:
        # Previously a missing argument surfaced as a bare IndexError.
        sys.exit("usage: python non_rg_metrics.py gold_tuple_fi pred_tuple_fi")
    calc_precrec(sys.argv[1], sys.argv[2])
    calc_dld(sys.argv[1], sys.argv[2])