-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepare.py
126 lines (106 loc) · 3.85 KB
/
prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from collections import defaultdict
import numpy as np
import json
import argparse
# Command-line options: the dataset directory and the dataset's name.
args = argparse.ArgumentParser()
args.add_argument(
    "-path", "--dataset_path",
    type=str,
    default="./NELL",
)
args.add_argument(
    "-data", "--dataset_name",
    type=str,
    default="NELL-One",
)
params = args.parse_args()

# Short aliases used by the rest of the script: `dire` is prepended to the
# file names the script reads and writes, `data` is compared against
# 'NELL-One' to choose how candidate lists are built.
dire, data = params.dataset_path, params.dataset_name
# File names (each with a leading '/') that get appended to the dataset
# directory when the inputs are loaded.
path = {
    # task splits
    'train_tasks': '/train_tasks.json',
    'test_tasks': '/test_tasks.json',
    'dev_tasks': '/dev_tasks.json',
    # lookup tables
    'rel2candidates': '/rel2candidates.json',
    'e1rel_e2': '/e1rel_e2.json',
    # background graph (plain text, one whitespace-separated triple per line)
    'path_graph': '/path_graph',
    # entity embeddings in a text format readable by np.loadtxt
    'ent2emb': '/entity2vec.ComplEx',
}
print('Start')
print('Process {} in {}'.format(data, dire))


def _load_json(file_path):
    """Parse one JSON file, closing its handle before returning the data."""
    with open(file_path) as handle:
        return json.load(handle)


# Load every input up front. Each file is opened inside a `with` block so the
# handle is released immediately instead of being left for garbage collection.
print("Loading jsons ... ...")
train_tasks = _load_json(dire + path['train_tasks'])
test_tasks = _load_json(dire + path['test_tasks'])
dev_tasks = _load_json(dire + path['dev_tasks'])
e1rel_e2 = _load_json(dire + path['e1rel_e2'])
with open(dire + path['path_graph']) as handle:
    # Raw text lines; they are split into triples further down the script.
    path_graph_lines = handle.readlines()
rel2candidates = _load_json(dire + path['rel2candidates'])
ent2emb = np.loadtxt(dire + path['ent2emb'], dtype=np.float32)

# convert entity2vec to .npy
# NOTE(review): this path is relative to the current working directory, not
# to `dire` like the other outputs of this script -- confirm that is intended.
np.save('ent2vec.npy', ent2emb)
# Split each background-graph line into its whitespace-separated fields.
# Fields 0 and 2 are collected into `entity`; the full field list is kept in
# `path_graph` (field 1 is used as the grouping key later in the script).
entity = set()
path_graph = []
for line in path_graph_lines:
    triple = line.strip().split()
    if not triple:
        # A blank line has no fields; indexing it would raise IndexError.
        continue
    entity.add(triple[0])
    entity.add(triple[2])
    path_graph.append(triple)

# Write through a context manager so the file is flushed and closed here.
with open(dire + '/path_graph.json', 'w') as handle:
    json.dump(path_graph, handle)
# train_tasks_in_train
print("Writing train_tasks_in_train.json ... ...")

# Group the background-graph triples by their middle field (index 1).
path_graph_tasks = defaultdict(list)
for triple in path_graph:
    path_graph_tasks[triple[1]].append(triple)

# Merge with the training tasks. When a key exists in both mappings the
# path-graph list replaces the train_tasks list; the two are not concatenated.
train_tasks_in_train = {**train_tasks, **path_graph_tasks}

# Write through a context manager so the file is flushed and closed here.
with open(dire + '/train_tasks_in_train.json', 'w') as handle:
    json.dump(train_tasks_in_train, handle)
# rel2candidates_in_train
# NOTE(review): the nesting below was reconstructed from a copy of the file
# that had lost its indentation. In particular, the single-candidate repair
# loop is placed inside the non-'NELL-One' branch -- confirm against upstream.


def _entity_type(name):
    """Return the middle field of an 'a:b:c' style name, else 'num'."""
    parts = name.split(':')
    return parts[1] if len(parts) == 3 else 'num'


if data == 'NELL-One':
    print("Writing rel2candidates_in_train.json ... ...")

    # Bucket every path-graph entity by its type tag.
    entity_dict = defaultdict(list)
    for ent in entity:
        entity_dict[_entity_type(ent)].append(ent)

    # For each path-graph relation, the candidates are all entities whose type
    # tag matches the tag of any third-field entity seen with that relation.
    rel2candidates_in_train = {}
    for rel, task in path_graph_tasks.items():
        types = {_entity_type(e2) for _e1, _r, e2 in task}
        cands = set()
        for t in types:
            cands.update(entity_dict[t])
        rel2candidates_in_train[rel] = list(cands)

    # Entries derived from the path graph override same-named entries of
    # rel2candidates.
    rel2candidates_in_train = {**rel2candidates, **rel2candidates_in_train}
else:
    print("Writing rel2candidates_in_train.json ... ...")

    # For each path-graph relation, the candidates are the distinct third
    # fields of its triples.
    rel2candidates_in_train = {
        rel: list({tri[2] for tri in triples})
        for rel, triples in path_graph_tasks.items()
    }
    rel2candidates_in_train = {**rel2candidates, **rel2candidates_in_train}

    # Try to enlarge candidate lists that hold exactly one entry by borrowing
    # the candidates of a relation that has a triple ending in that entry.
    for rel, cands in rel2candidates_in_train.items():
        if len(cands) != 1:
            continue
        one_cand = cands[0]
        for k, v in train_tasks_in_train.items():
            for tri in v:
                h, r, t = tri
                if t == one_cand:
                    cands.extend(rel2candidates_in_train[r])
                    break
            # NOTE(review): this counts duplicates, so borrowing a list that
            # only contains `one_cand` (including this relation's own list)
            # ends the search and the de-duplicated result below can still
            # hold a single candidate.
            if len(cands) > 1:
                break
        # Rebinding an existing key does not resize the dict, so this is safe
        # while iterating over .items().
        rel2candidates_in_train[rel] = list(set(cands))

# Write through a context manager so the file is flushed and closed here.
with open(dire + '/rel2candidates_in_train.json', 'w') as handle:
    json.dump(rel2candidates_in_train, handle)
# e1rel_e2_in_train
print("Writing e1rel_e2_in_train.json ... ...")

# Index the path-graph triples by the concatenation of their first two
# fields, mapping each key to the list of third fields seen with it.
e1rel_e2_in_train = defaultdict(list)
for triples in path_graph_tasks.values():
    for e1, r, e2 in triples:
        e1rel_e2_in_train[e1 + r].append(e2)

# Keys produced from the path graph replace same-named keys of e1rel_e2; the
# lists are not concatenated.
e1rel_e2_in_train = {**e1rel_e2, **e1rel_e2_in_train}

# Write through a context manager so the file is flushed and closed here.
with open(dire + '/e1rel_e2_in_train.json', 'w') as handle:
    json.dump(e1rel_e2_in_train, handle)
print('End')