-
Notifications
You must be signed in to change notification settings - Fork 44
/
aggregator.py
57 lines (44 loc) · 2.32 KB
/
aggregator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import torch.nn.functional as F
class Aggregator(torch.nn.Module):
    '''
    Aggregator class
    Mode in ['sum', 'concat', 'neighbor']

    Combines each entity's own vector with a user-specific weighted summary
    of its neighbors' vectors, passes the result through a linear layer, and
    applies a caller-supplied activation.

    Modes:
        'sum'      -- linear(self_vectors + neighbors_agg)
        'concat'   -- linear(cat(self_vectors, neighbors_agg)); the linear
                      layer therefore takes 2 * dim input features
        'neighbor' -- linear(neighbors_agg); self_vectors are ignored

    Raises:
        ValueError: if ``aggregator`` is not one of the modes above.
    '''

    _MODES = ('sum', 'concat', 'neighbor')

    def __init__(self, batch_size, dim, aggregator):
        super().__init__()
        # Reject unknown modes up front; previously any unrecognised string
        # silently behaved like 'neighbor'.
        if aggregator not in self._MODES:
            raise ValueError(
                f"aggregator must be one of {self._MODES}, got {aggregator!r}"
            )
        # Most recently seen batch size; updated by forward(). Kept as an
        # attribute for backward compatibility with existing callers.
        self.batch_size = batch_size
        self.dim = dim
        # 'concat' feeds [self, neighbors] side by side, doubling the width.
        in_features = 2 * dim if aggregator == 'concat' else dim
        self.weights = torch.nn.Linear(in_features, dim, bias=True)
        self.aggregator = aggregator

    def forward(self, self_vectors, neighbor_vectors, neighbor_relations, user_embeddings, act):
        '''
        Aggregate one hop of neighbor information.

        Args:
            self_vectors: entity vectors, presumably [batch_size, -1, dim]
                (TODO confirm against caller).
            neighbor_vectors: [batch_size, -1, n_neighbor, dim].
            neighbor_relations: [batch_size, -1, n_neighbor, dim].
            user_embeddings: one embedding per batch row; size(0) is taken
                as the batch size and it is reshaped to [batch_size, 1, 1, dim].
            act: activation callable applied to the final output.

        Returns:
            act(...) of a tensor shaped [batch_size, -1, dim].
        '''
        batch_size = user_embeddings.size(0)
        # Record the batch size on the module (original behaviour), but use
        # the local value below so correctness does not hinge on this state.
        self.batch_size = batch_size
        neighbors_agg = self._mix_neighbor_vectors(neighbor_vectors, neighbor_relations, user_embeddings)
        if self.aggregator == 'sum':
            output = (self_vectors + neighbors_agg).reshape((-1, self.dim))
        elif self.aggregator == 'concat':
            output = torch.cat((self_vectors, neighbors_agg), dim=-1)
            output = output.reshape((-1, 2 * self.dim))
        else:  # 'neighbor' (validated in __init__)
            output = neighbors_agg.reshape((-1, self.dim))
        output = self.weights(output)
        return act(output.reshape((batch_size, -1, self.dim)))

    def _mix_neighbor_vectors(self, neighbor_vectors, neighbor_relations, user_embeddings):
        '''
        This aims to aggregate neighbor vectors

        Scores every neighbor relation against the user embedding (dot product
        over the last axis), softmax-normalises the scores across neighbors,
        and returns the score-weighted sum of the neighbor vectors.
        '''
        # Take the batch size from the input itself instead of relying on
        # forward() having updated self.batch_size beforehand.
        batch_size = user_embeddings.size(0)
        # [batch_size, 1, dim] -> [batch_size, 1, 1, dim]
        user_embeddings = user_embeddings.reshape((batch_size, 1, 1, self.dim))
        # [batch_size, -1, n_neighbor, dim] -> [batch_size, -1, n_neighbor]
        user_relation_scores = (user_embeddings * neighbor_relations).sum(dim=-1)
        user_relation_scores_normalized = F.softmax(user_relation_scores, dim=-1)
        # [batch_size, -1, n_neighbor] -> [batch_size, -1, n_neighbor, 1]
        user_relation_scores_normalized = user_relation_scores_normalized.unsqueeze(dim=-1)
        # [batch_size, -1, n_neighbor, 1] * [batch_size, -1, n_neighbor, dim] -> [batch_size, -1, dim]
        neighbors_aggregated = (user_relation_scores_normalized * neighbor_vectors).sum(dim=2)
        return neighbors_aggregated