-
Notifications
You must be signed in to change notification settings - Fork 0
/
dice_alignment.py
123 lines (112 loc) · 3.98 KB
/
dice_alignment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import argparse
import math
import numpy as np
from tqdm import tqdm
import torch
from collections import defaultdict
import random
# from scipy.optimize import linear_sum_assignment
# A Dice-coefficient based aligner, roughly based on:
# https://www.aclweb.org/anthology/P97-1063.pdf
# I. Dan Melamed. "A Word-to-Word Model of Translational Equivalence". ACL 1997
# Command-line interface.
#   srctrg_file     parallel corpus, one sentence pair per line written as
#                   "source ||| target"
#   --thresh        links are emitted only while the best remaining score is
#                   strictly greater than this value
#   --eqn           association measure: "dice" or "pmi"
#   --sub_concrete  subtract the squared concreteness difference from each score
parser = argparse.ArgumentParser(description="Parse arguments")
parser.add_argument("srctrg_file", help="Source file")
# type=float is required here: without it a threshold supplied on the command
# line stays a string and cannot be compared with the numeric score matrix.
parser.add_argument(
    "--thresh", type=float, default=-1.0e5, help="The threshold"
)
parser.add_argument(
    "--eqn", default="dice", help="What type of equation to use (dice/pmi)"
)
parser.add_argument(
    "--sub_concrete",
    action="store_true",
    help="whether to subtract scores with concreteness difference",
)
args = parser.parse_args()
# First pass over the corpus: sentence-level occurrence counts.
#   src_cnt[s]        sentence pairs whose source side contains s
#   trg_cnt[t]        sentence pairs whose target side contains t
#   srctrg_cnt[s, t]  sentence pairs containing s (source) and t (target)
#   tot_cnt           sentence pairs counted
# Tokens are lowercased and de-duplicated per sentence, so a word contributes
# at most once per sentence pair.
srctrg_cnt = defaultdict(int)
src_cnt = defaultdict(int)
trg_cnt = defaultdict(int)
tot_cnt = 0
with open(args.srctrg_file, "r") as fsrctrg:
    for raw_pair in tqdm(fsrctrg):
        fields = raw_pair.strip().split(" ||| ")
        # A blank line splits into [""], so this single length test also
        # covers empty input as well as lines lacking the separator.
        if len(fields) == 1:
            continue
        tot_cnt += 1.0
        src_side, trg_side = fields
        src_vocab = set(src_side.lower().split())
        trg_vocab = set(trg_side.lower().split())
        for trg_word in trg_vocab:
            trg_cnt[trg_word] += 1.0
        for src_word in src_vocab:
            src_cnt[src_word] += 1.0
            for trg_word in trg_vocab:
                srctrg_cnt[src_word, trg_word] += 1.0
# Turn the raw counts into one association score per co-occurring pair.
#   dice: 2 * c(s, t) / (c(s) + c(t))
#   pmi : log(c(s, t) * N / c(s) / c(t)), with N = tot_cnt
if args.eqn not in ("dice", "pmi"):
    raise ValueError(f"Illegal equation {args.eqn}")
srctrg_score = {}
for (src_word, trg_word), joint in tqdm(srctrg_cnt.items()):
    if args.eqn == "dice":
        pair_score = 2 * joint / (src_cnt[src_word] + trg_cnt[trg_word])
    else:
        pair_score = math.log(
            joint * tot_cnt / src_cnt[src_word] / trg_cnt[trg_word]
        )
    srctrg_score[src_word, trg_word] = pair_score
# =============LOAD CONCRETE=================
# Build `con`, a lookup from word to concreteness rating.
# The ratings file is tab-separated: column 0 holds the word (internal spaces
# are replaced by "-" to form the key) and column 2 holds the rating, which is
# stored exactly as read, i.e. as a string.
concrete_file = "./data/concreteness_modified.txt"
# Read inside a context manager so the handle is closed as soon as the lines
# are in memory instead of being left open for the rest of the run.
with open(concrete_file, "r") as fconcrete:
    concretes = fconcrete.readlines()
con = {}
for line in concretes:
    # A blank line (e.g. a trailing newline at end of file) has no columns to
    # index; skip it rather than fail on the missing rating column.
    if not line.strip():
        continue
    line = line.rstrip().split("\t")
    word = "-".join(line[0].split(" "))
    score = line[2]
    con[word] = score
# =============END LOAD CONCRETE=================
def _concreteness_column(tokens):
    """Return an (n, 1) tensor holding one concreteness value per token.

    Tokens present in ``con`` use their rating divided by 5.0; tokens missing
    from ``con`` receive a fresh ``random.uniform(0, 1)`` draw.
    """
    values = [
        float(con[tok]) / 5.0 if tok in con else random.uniform(0, 1)
        for tok in tokens
    ]
    return torch.tensor(values).unsqueeze(-1)


def _greedy_links(scores, thresh):
    """Greedily select one-to-one links from a score matrix.

    Repeatedly takes the highest remaining cell while it is strictly greater
    than ``thresh``, then overwrites that cell's row and column with
    ``thresh`` so neither word can be linked again. ``scores`` is modified in
    place.

    Returns a list of ``(row, col, score)`` tuples in selection order. An
    empty matrix yields an empty list (``np.argmax`` cannot be called on it).
    """
    links = []
    if scores.size == 0:
        return links
    row, col = np.unravel_index(np.argmax(scores), scores.shape)
    while scores[row, col] > thresh:
        links.append((row, col, scores[row, col]))
        scores[row, :] = thresh
        scores[:, col] = thresh
        row, col = np.unravel_index(np.argmax(scores), scores.shape)
    return links


# Second pass over the corpus: print one line of "src:trg:score" links per
# input line. Skipped input lines still print an empty line so the output stays
# line-aligned with the input.
# args.thresh may arrive as a string when supplied on the command line; coerce
# it once so every comparison and assignment below is numeric.
thresh = float(args.thresh)
with open(args.srctrg_file, "r") as fsrctrg:
    for lsrctrg in tqdm(fsrctrg):
        lsrctrg = lsrctrg.strip()
        fields = lsrctrg.split(" ||| ")
        # A blank line splits into [""], so the length test covers it too.
        if len(fields) == 1:
            print()
            continue
        lsrc, ltrg = fields
        # De-duplicate while keeping sentence order. Iterating a set instead
        # would make tie-breaking between equal scores vary from run to run
        # under string hash randomisation.
        wsrc = list(dict.fromkeys(lsrc.lower().split()))
        if "_" in wsrc:
            wsrc.remove("_")  # remove placeholder
        wtrg = list(dict.fromkeys(ltrg.lower().split()))
        dsrctrg = np.zeros((len(wsrc), len(wtrg)))
        # Squared concreteness difference for every (source, target) cell;
        # only built when it is actually subtracted from the scores.
        M = None
        if args.sub_concrete:
            csrc = _concreteness_column(wsrc)
            ctrg = _concreteness_column(wtrg)
            M = (csrc - ctrg.T) ** 2
            assert M.shape == dsrctrg.shape
        for i, s in enumerate(wsrc):
            for j, t in enumerate(wtrg):
                cell = srctrg_score[s, t]
                if M is not None:
                    cell = cell - M[i, j]
                dsrctrg[i, j] = cell
        # A source side consisting solely of the "_" placeholder leaves an
        # empty matrix; _greedy_links returns no links for it, so an empty
        # line is printed instead of failing inside np.argmax.
        aligns = [
            f"{wsrc[idsrc]}:{wtrg[idtrg]}:{value:.3f}"
            for idsrc, idtrg, value in _greedy_links(dsrctrg, thresh)
        ]
        print(" ".join(aligns))