# rgcn.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing


def uniform(size, tensor):
    # Initialize `tensor` uniformly in [-1/sqrt(size), 1/sqrt(size)].
    bound = 1.0 / math.sqrt(size)
    if tensor is not None:
        tensor.data.uniform_(-bound, bound)


class RGCN(torch.nn.Module):
    def __init__(self, num_entities, num_relations, num_bases, dropout):
        super(RGCN, self).__init__()

        self.entity_embedding = nn.Embedding(num_entities, 100)
        self.relation_embedding = nn.Parameter(torch.Tensor(num_relations, 100))
        nn.init.xavier_uniform_(self.relation_embedding,
                                gain=nn.init.calculate_gain('relu'))

        # Each edge direction gets its own relation id, hence the factor of
        # two on num_relations.
        self.conv1 = RGCNConv(100, 100, num_relations * 2, num_bases=num_bases)
        self.conv2 = RGCNConv(100, 100, num_relations * 2, num_bases=num_bases)

        self.dropout_ratio = dropout

    def forward(self, entity, edge_index, edge_type, edge_norm):
        x = self.entity_embedding(entity)
        x = F.relu(self.conv1(x, edge_index, edge_type, edge_norm))
        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
        x = self.conv2(x, edge_index, edge_type, edge_norm)
        return x
    def distmult(self, embedding, triplets):
        # DistMult score for each (subject, relation, object) triplet:
        # sum_k s_k * r_k * o_k.
        s = embedding[triplets[:, 0]]
        r = self.relation_embedding[triplets[:, 1]]
        o = embedding[triplets[:, 2]]
        score = torch.sum(s * r * o, dim=1)
        return score

    def score_loss(self, embedding, triplets, target):
        score = self.distmult(embedding, triplets)
        return score, F.binary_cross_entropy_with_logits(score, target)

    def reg_loss(self, embedding):
        # L2 regularization over entity and relation embeddings.
        return torch.mean(embedding.pow(2)) + torch.mean(self.relation_embedding.pow(2))


class RGCNConv(MessagePassing):
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot
        \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types.
    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which
    stores a relation identifier
    :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        num_bases (int): Number of bases used for basis-decomposition.
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
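
    Example (a minimal sketch; the sizes and edge data below are
    illustrative, not from the original training setup)::

        conv = RGCNConv(16, 32, num_relations=4, num_bases=2)
        x = torch.randn(10, 16)                     # 10 nodes, 16 features
        edge_index = torch.tensor([[0, 1, 2],       # source nodes
                                   [1, 2, 3]])      # target nodes
        edge_type = torch.tensor([0, 1, 3])         # relation id per edge
        out = conv(x, edge_index, edge_type)        # shape [10, 32]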
"""

    def __init__(self, in_channels, out_channels, num_relations, num_bases,
                 root_weight=True, bias=True, **kwargs):
        super(RGCNConv, self).__init__(aggr='mean', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases

        # Basis decomposition: each relation's weight matrix is a learned
        # linear combination (`att`) of `num_bases` shared basis matrices.
        self.basis = nn.Parameter(torch.Tensor(num_bases, in_channels, out_channels))
        self.att = nn.Parameter(torch.Tensor(num_relations, num_bases))

        if root_weight:
            self.root = nn.Parameter(torch.Tensor(in_channels, out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        size = self.num_bases * self.in_channels
        uniform(size, self.basis)
        uniform(size, self.att)
        uniform(size, self.root)
        uniform(size, self.bias)

    def forward(self, x, edge_index, edge_type, edge_norm=None, size=None):
        """"""
        return self.propagate(edge_index, size=size, x=x, edge_type=edge_type,
                              edge_norm=edge_norm)

    def message(self, x_j, edge_index_j, edge_type, edge_norm):
        # Materialize per-relation weights from the basis decomposition.
        w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))

        # If no node features are given, we implement a simple embedding
        # lookup based on the target node index and its edge type.
        if x_j is None:
            w = w.view(-1, self.out_channels)
            index = edge_type * self.in_channels + edge_index_j
            out = torch.index_select(w, 0, index)
        else:
            w = w.view(self.num_relations, self.in_channels, self.out_channels)
            w = torch.index_select(w, 0, edge_type)
            out = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)

        return out if edge_norm is None else out * edge_norm.view(-1, 1)

    def update(self, aggr_out, x):
        # Start from the aggregated messages so `out` is defined even when
        # the layer was built with root_weight=False.
        out = aggr_out
        if self.root is not None:
            if x is None:
                out = aggr_out + self.root
            else:
                out = aggr_out + torch.matmul(x, self.root)
        if self.bias is not None:
            out = out + self.bias
        return out

    def __repr__(self):
        return '{}({}, {}, num_relations={})'.format(
            self.__class__.__name__, self.in_channels, self.out_channels,
            self.num_relations)
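

if __name__ == '__main__':
    # Minimal smoke test -- a sketch with illustrative sizes only; the
    # original training pipeline (dataset loading, negative sampling, and
    # the proper edge_norm computation) lives outside this file.
    num_entities, num_relations = 100, 5
    model = RGCN(num_entities, num_relations, num_bases=4, dropout=0.2)

    entity = torch.arange(num_entities)
    edge_index = torch.randint(num_entities, (2, 300))
    # Relation ids span both edge directions, hence 2 * num_relations.
    edge_type = torch.randint(num_relations * 2, (300,))
    edge_norm = torch.ones(edge_index.size(1))

    embedding = model(entity, edge_index, edge_type, edge_norm)

    # Score a handful of random (subject, relation, object) triplets.
    triplets = torch.randint(num_entities, (20, 3))
    triplets[:, 1] = torch.randint(num_relations, (20,))
    target = torch.randint(2, (20,)).float()
    score, loss = model.score_loss(embedding, triplets, target)
    loss = loss + 0.01 * model.reg_loss(embedding)
    print(embedding.shape, loss.item())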