-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTransH.py
119 lines (107 loc) · 4.05 KB
/
TransH.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import torch.nn as nn
import torch.nn.functional as F
from Model import Model
class TransH(Model):
    """TransH knowledge-graph embedding model.

    Each relation owns a translation vector (``rel_embeddings``) and a normal
    vector (``norm_vector``). Head and tail entity embeddings are projected
    onto the hyperplane orthogonal to the L2-normalised normal vector, and the
    score is the ``p_norm`` distance of ``h_proj + r - t_proj``.

    Args:
        ent_tot: Number of entities (size of ``ent_embeddings``).
        rel_tot: Number of relations (size of ``rel_embeddings`` / ``norm_vector``).
        dim: Embedding dimension.
        p_norm: Order of the norm passed to ``torch.norm`` in ``_calc``.
        norm_flag: If true, L2-normalise the projected h, t and r before scoring.
        margin: Optional margin. When given, ``forward`` returns
            ``margin - distance`` instead of the raw distance.
        epsilon: Used only together with ``margin``: if both are given, the
            embeddings are initialised uniformly in
            ``[-(margin + epsilon) / dim, (margin + epsilon) / dim]``;
            otherwise Xavier-uniform initialisation is used.
    """

    def __init__(self, ent_tot, rel_tot, dim=100, p_norm=1, norm_flag=True, margin=None, epsilon=None):
        super(TransH, self).__init__(ent_tot, rel_tot)
        self.dim = dim
        self.margin = margin
        self.epsilon = epsilon
        self.norm_flag = norm_flag
        self.p_norm = p_norm

        # NOTE(review): self.ent_tot / self.rel_tot are assumed to be set by the
        # Model base class from the constructor arguments.
        self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim)
        self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim)
        self.norm_vector = nn.Embedding(self.rel_tot, self.dim)
        # Fixed order (entities, relations, normals) so RNG consumption - and
        # therefore seeded initialisation - matches the original code.
        embeddings = (self.ent_embeddings, self.rel_embeddings, self.norm_vector)

        if margin is None or epsilon is None:
            for emb in embeddings:
                nn.init.xavier_uniform_(emb.weight.data)
        else:
            # Stored as a frozen Parameter (rather than a float) so it is part
            # of state_dict, exactly as before.
            self.embedding_range = nn.Parameter(
                torch.tensor([(self.margin + self.epsilon) / self.dim], dtype=torch.get_default_dtype()),
                requires_grad=False,
            )
            bound = self.embedding_range.item()
            for emb in embeddings:
                nn.init.uniform_(tensor=emb.weight.data, a=-bound, b=bound)

        if margin is not None:
            # Explicit default float dtype so an int margin still yields a float
            # tensor (the legacy torch.Tensor constructor behaved this way).
            self.margin = nn.Parameter(torch.tensor([margin], dtype=torch.get_default_dtype()))
            self.margin.requires_grad = False
            self.margin_flag = True
        else:
            self.margin_flag = False

    def _calc(self, h, t, r, mode):
        """Return the ``p_norm`` distance of ``h + r - t`` over the last axis, flattened.

        For any ``mode`` other than ``'normal'``, h, t and r are reshaped to
        ``(-1, r.shape[0], dim)`` so that each relation row broadcasts against
        several candidate h / t rows. (Presumably this serves negative-sampling /
        evaluation batches; confirm against the data loader.)
        """
        if self.norm_flag:
            h = F.normalize(h, 2, -1)
            r = F.normalize(r, 2, -1)
            t = F.normalize(t, 2, -1)
        if mode != 'normal':
            # Use the row count of r *before* it is reshaped, as in the original.
            rows = r.shape[0]
            h = h.view(-1, rows, h.shape[-1])
            t = t.view(-1, rows, t.shape[-1])
            r = r.view(-1, rows, r.shape[-1])
        if mode == 'head_batch':
            # Algebraically equal to (h + r) - t; the grouping is kept so that
            # floating-point results are unchanged.
            score = h + (r - t)
        else:
            score = (h + r) - t
        return torch.norm(score, self.p_norm, -1).flatten()

    def _transfer(self, e, norm):
        """Project ``e`` onto the hyperplane whose normal is ``norm``.

        ``norm`` is L2-normalised to ``w`` and ``e - (e . w) w`` is returned.
        If ``e`` and ``norm`` have different first dimensions, ``e`` is
        reshaped to ``(-1, norm.shape[0], dim)`` so ``w`` broadcasts across it,
        and the projection is flattened back to ``(-1, dim)``.
        """
        norm = F.normalize(norm, p=2, dim=-1)
        if e.shape[0] == norm.shape[0]:
            return e - torch.sum(e * norm, -1, True) * norm
        rows = norm.shape[0]
        e = e.view(-1, rows, e.shape[-1])
        norm = norm.view(-1, rows, norm.shape[-1])
        e = e - torch.sum(e * norm, -1, True) * norm
        return e.view(-1, e.shape[-1])

    def startingBatch(self):
        # Per-batch hook; TransH has no per-batch state to reset.
        return

    def _distance(self, data):
        """Raw TransH distance for a batch dict, with no margin applied.

        ``data`` must provide ``'batch_h'``, ``'batch_t'``, ``'batch_r'``
        (index tensors for the embedding tables) and ``'mode'``.
        """
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        mode = data['mode']
        h = self.ent_embeddings(batch_h)
        t = self.ent_embeddings(batch_t)
        r = self.rel_embeddings(batch_r)
        r_norm = self.norm_vector(batch_r)
        h = self._transfer(h, r_norm)
        t = self._transfer(t, r_norm)
        return self._calc(h, t, r, mode)

    def forward(self, data):
        """Score a batch: ``margin - distance`` if a margin was given, else the raw distance."""
        score = self._distance(data)
        if self.margin_flag:
            return self.margin - score
        return score

    def regularization(self, data):
        """Return the mean of squared h, t, r and normal-vector embeddings (averaged over the four)."""
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        h = self.ent_embeddings(batch_h)
        t = self.ent_embeddings(batch_t)
        r = self.rel_embeddings(batch_r)
        r_norm = self.norm_vector(batch_r)
        regul = (torch.mean(h ** 2) +
                 torch.mean(t ** 2) +
                 torch.mean(r ** 2) +
                 torch.mean(r_norm ** 2)) / 4
        return regul

    def predict(self, data):
        """Return the raw distance (margin never applied) as a NumPy array.

        Previously this computed ``margin - (margin - distance)`` through
        ``forward``, which only added float rounding; the distance is now
        taken directly.
        """
        score = self._distance(data)
        return score.detach().cpu().numpy()