-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
84 lines (72 loc) · 2.66 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from trainer import *
from params import *
from data_loader import *
import json
# Make the parent of the directory containing this script importable.
# `os` and `sys` are imported explicitly instead of relying on them leaking
# out of the wildcard imports above.
# NOTE(review): this runs *after* the trainer/params/data_loader imports, so
# it cannot affect how those modules are resolved -- only imports performed
# later at runtime. Move it above them if that was the intent.
import os
import sys

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)
# Guard against duplicate entries if this module is executed more than once.
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)
if __name__ == "__main__":
params = get_params()
print("---------Parameters---------")
for k, v in params.items():
print(k + ": " + str(v))
print("----------------------------")
# control random seed
if params["seed"] is not None:
SEED = params["seed"]
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
np.random.seed(SEED)
random.seed(SEED)
# select the dataset
for k, v in data_dir.items():
data_dir[k] = params["data_path"] + v
tail = ""
if params["data_form"] == "In-Train":
tail = "_in_train"
dataset = dict()
print("loading train_tasks{} ... ...".format(tail))
dataset["train_tasks"] = json.load(open(data_dir["train_tasks" + tail]))
print("loading test_tasks ... ...")
dataset["test_tasks"] = json.load(open(data_dir["test_tasks"]))
print("loading dev_tasks ... ...")
dataset["dev_tasks"] = json.load(open(data_dir["dev_tasks"]))
print("loading rel2candidates{} ... ...".format(tail))
dataset["rel2candidates"] = json.load(
open(data_dir["rel2candidates" + tail])
)
print("loading e1rel_e2{} ... ...".format(tail))
dataset["e1rel_e2"] = json.load(open(data_dir["e1rel_e2" + tail]))
print("loading ent2id ... ...")
dataset["ent2id"] = json.load(open(data_dir["ent2ids"]))
if params["data_form"] in ["Pre-Train","In-Train"]:
print("loading embedding ... ...")
dataset["ent2emb"] = np.load(data_dir["ent2vec"])
print("----------------------------")
# data_loader
train_data_loader = DataLoader(dataset, params, step="train")
dev_data_loader = DataLoader(dataset, params, step="dev")
test_data_loader = DataLoader(dataset, params, step="test")
data_loaders = [train_data_loader, dev_data_loader, test_data_loader]
# trainer
trainer = Trainer(data_loaders, dataset, params)
if params["step"] == "train":
trainer.train()
print("test")
print(params["prefix"])
trainer.reload()
trainer.eval(istest=True)
elif params["step"] == "test":
print(params["prefix"])
if params["eval_by_rel"]:
trainer.eval_by_relation(istest=True)
else:
trainer.eval(istest=True)
elif params["step"] == "dev":
print(params["prefix"])
if params["eval_by_rel"]:
trainer.eval_by_relation(istest=False)
else:
trainer.eval(istest=False)