-
Notifications
You must be signed in to change notification settings - Fork 611
/
gcsan.py
261 lines (221 loc) · 10.4 KB
/
gcsan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
# -*- coding: utf-8 -*-
# @Time : 2020/10/4 16:55
# @Author : Yujie Lu
# @Email : yujielu1998@gmail.com
r"""
GCSAN
################################################
Reference:
Chengfeng Xu et al. "Graph Contextualized Self-Attention Network for Session-based Recommendation." in IJCAI 2019.
"""
import math
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F
from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.layers import TransformerEncoder
from recbole.model.loss import EmbLoss, BPRLoss
class GNN(nn.Module):
    r"""Gated graph neural network that propagates item-node states over a session graph.

    Graph neural networks are well-suited for session-based recommendation because
    they can automatically extract features of session graphs while taking rich
    node connections into account.

    Args:
        embedding_size (int): Dimensionality of the node states.
        step (int): Number of propagation steps applied by :meth:`forward`.
    """

    def __init__(self, embedding_size, step=1):
        super().__init__()
        self.step = step
        self.embedding_size = embedding_size
        # The GRU-style cell consumes concatenated incoming/outgoing messages
        # (2 * d) and emits reset, update and candidate pre-activations (3 * d).
        self.input_size = 2 * embedding_size
        self.gate_size = 3 * embedding_size

        # NOTE: registration order matters - it fixes both the state_dict keys
        # and the order in which _reset_parameters consumes random numbers.
        self.w_ih = Parameter(torch.Tensor(self.gate_size, self.input_size))
        self.w_hh = Parameter(torch.Tensor(self.gate_size, self.embedding_size))
        self.b_ih = Parameter(torch.Tensor(self.gate_size))
        self.b_hh = Parameter(torch.Tensor(self.gate_size))
        self.linear_edge_in = nn.Linear(
            self.embedding_size, self.embedding_size, bias=True
        )
        self.linear_edge_out = nn.Linear(
            self.embedding_size, self.embedding_size, bias=True
        )

        self._reset_parameters()

    def _reset_parameters(self):
        """Draw every parameter (edge linears included) from U(-1/sqrt(d), 1/sqrt(d))."""
        bound = 1.0 / math.sqrt(self.embedding_size)
        with torch.no_grad():
            for param in self.parameters():
                param.uniform_(-bound, bound)

    def GNNCell(self, A, hidden):
        r"""Run one gated propagation step over the session graph.

        Args:
            A (torch.FloatTensor): The connection matrix, shape of
                [batch_size, max_session_len, 2 * max_session_len]; the first half of
                the last axis holds incoming edges, the second half outgoing edges.
            hidden (torch.FloatTensor): The item node embedding matrix, shape of
                [batch_size, max_session_len, embedding_size].

        Returns:
            torch.FloatTensor: Latent vectors of nodes, shape of
            [batch_size, max_session_len, embedding_size].
        """
        n_nodes = A.size(1)
        adj_in = A[:, :, :n_nodes]
        adj_out = A[:, :, n_nodes : 2 * n_nodes]

        # Aggregate neighbour messages along each edge direction, then
        # concatenate: [batch_size, max_session_len, embedding_size * 2].
        msg_in = torch.matmul(adj_in, self.linear_edge_in(hidden))
        msg_out = torch.matmul(adj_out, self.linear_edge_out(hidden))
        inputs = torch.cat([msg_in, msg_out], dim=2)

        # Both projections: [batch_size, max_session_len, embedding_size * 3].
        gi = F.linear(inputs, self.w_ih, self.b_ih)
        gh = F.linear(hidden, self.w_hh, self.b_hh)

        # Split into the three GRU gate pre-activations, each of width embedding_size.
        i_r, i_i, i_n = gi.chunk(3, dim=2)
        h_r, h_i, h_n = gh.chunk(3, dim=2)

        reset_gate = torch.sigmoid(i_r + h_r)
        update_gate = torch.sigmoid(i_i + h_i)
        candidate = torch.tanh(i_n + reset_gate * h_n)
        return (1 - update_gate) * hidden + update_gate * candidate

    def forward(self, A, hidden):
        """Apply :meth:`GNNCell` ``self.step`` times, feeding each output back in."""
        for _ in range(self.step):
            hidden = self.GNNCell(A, hidden)
        return hidden
class GCSAN(SequentialRecommender):
    r"""GCSAN captures rich local dependencies via a graph neural network,
    and learns long-range dependencies by applying the self-attention mechanism.

    Note:

        In the original paper, the attention mechanism in the self-attention layer is a single head;
        for the reusability of the project code, we use a unified transformer component.
        According to the experimental results, we only applied regularization to embedding.
    """

    def __init__(self, config, dataset):
        super(GCSAN, self).__init__(config, dataset)

        # load parameters info
        self.n_layers = config["n_layers"]
        self.n_heads = config["n_heads"]
        self.hidden_size = config["hidden_size"]  # same as embedding_size
        self.inner_size = config["inner_size"]  # the dimensionality in feed-forward layer
        self.hidden_dropout_prob = config["hidden_dropout_prob"]
        self.attn_dropout_prob = config["attn_dropout_prob"]
        self.hidden_act = config["hidden_act"]
        self.layer_norm_eps = config["layer_norm_eps"]
        self.step = config["step"]
        self.device = config["device"]
        self.weight = config["weight"]
        self.reg_weight = config["reg_weight"]
        self.loss_type = config["loss_type"]
        self.initializer_range = config["initializer_range"]

        # `weight` mixes the self-attention and GNN representations in forward().
        # Validate it once, with an explicit exception: an `assert` would be
        # re-checked on every batch and is skipped entirely under `python -O`.
        if not 0 <= self.weight <= 1:
            raise ValueError(f"'weight' must be in [0, 1], got {self.weight!r}")

        # define layers and loss
        self.item_embedding = nn.Embedding(
            self.n_items, self.hidden_size, padding_idx=0
        )
        self.gnn = GNN(self.hidden_size, self.step)
        self.self_attention = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps,
        )
        self.reg_loss = EmbLoss()
        if self.loss_type == "BPR":
            self.loss_fct = BPRLoss()
        elif self.loss_type == "CE":
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the weights of ``module``.

        Linear and Embedding weights are drawn from N(0, initializer_range);
        LayerNorm is reset to weight 1 / bias 0; Linear biases are zeroed.

        NOTE(review): because this runs through ``self.apply``, it also overrides
        the uniform initialization of the GNN's ``nn.Linear`` edge layers (its bare
        ``Parameter`` gate weights keep theirs), and it overwrites the embedding's
        padding row 0 with normal noise. Behavior kept as-is; confirm it is intended.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @staticmethod
    def _session_graph(u_input, max_n_node):
        """Build the normalized session graph for one padded item sequence.

        Args:
            u_input (numpy.ndarray): One row of item ids of length ``max_n_node``;
                transitions stop at the first following 0 (padding).
            max_n_node (int): Number of node slots (the padded sequence length).

        Returns:
            tuple: ``(items, adjacency, alias)`` where ``items`` is the sorted unique
            item ids zero-padded to ``max_n_node``; ``adjacency`` is a
            ``[max_n_node, 2 * max_n_node]`` array whose first half holds incoming
            edges normalized by in-degree and whose second half holds outgoing edges
            normalized by out-degree; and ``alias`` maps every sequence position to
            the index of its item in ``items``.
        """
        # np.unique returns the sorted unique ids; return_inverse gives, for every
        # position, the index of that position's item within `node`. This is the
        # same index np.where(node == x)[0][0] would find, in one pass.
        node, alias = np.unique(u_input, return_inverse=True)
        items = node.tolist() + (max_n_node - len(node)) * [0]

        u_A = np.zeros((max_n_node, max_n_node))
        for i in range(len(u_input) - 1):
            if u_input[i + 1] == 0:  # reached padding: no further transitions
                break
            u_A[alias[i], alias[i + 1]] = 1

        # Degree-normalize; zero degrees are set to 1 to avoid dividing by zero.
        u_sum_in = np.sum(u_A, 0)
        u_sum_in[u_sum_in == 0] = 1
        u_A_in = np.divide(u_A, u_sum_in)
        u_sum_out = np.sum(u_A, 1)
        u_sum_out[u_sum_out == 0] = 1
        u_A_out = np.divide(u_A.transpose(), u_sum_out)
        adjacency = np.concatenate([u_A_in, u_A_out]).transpose()
        return items, adjacency, alias.tolist()

    def _get_slice(self, item_seq):
        """Convert a batch of padded item sequences into session-graph tensors.

        Args:
            item_seq (torch.LongTensor): Item ids, shape [batch_size, max_session_len].

        Returns:
            tuple: ``(alias_inputs, A, items)`` on ``self.device``:
            the relative coordinates of each item node, shape [batch_size, max_session_len];
            the connecting matrix, shape [batch_size, max_session_len, 2 * max_session_len];
            the unique item nodes, shape [batch_size, max_session_len].
        """
        max_n_node = item_seq.size(1)
        items, A, alias_inputs = [], [], []
        for u_input in item_seq.cpu().numpy():
            node_items, u_A, alias = self._session_graph(u_input, max_n_node)
            items.append(node_items)
            A.append(u_A)
            alias_inputs.append(alias)

        alias_inputs = torch.LongTensor(alias_inputs).to(self.device)
        A = torch.FloatTensor(np.array(A)).to(self.device)
        items = torch.LongTensor(items).to(self.device)
        return alias_inputs, A, items

    def forward(self, item_seq, item_seq_len):
        """Encode each session into one vector mixing GNN and self-attention states.

        Args:
            item_seq (torch.LongTensor): Item ids, shape [batch_size, max_session_len].
            item_seq_len (torch.LongTensor): Number of real items per session
                (presumably shape [batch_size]; used as ``item_seq_len - 1`` to pick
                the last position).

        Returns:
            torch.FloatTensor: ``weight * attention_state + (1 - weight) * gnn_state``
            at each session's last position.
        """
        alias_inputs, A, items = self._get_slice(item_seq)
        hidden = self.item_embedding(items)
        hidden = self.gnn(A, hidden)

        # Scatter node states back into sequence order: position t receives the
        # state of node alias_inputs[:, t].
        alias_inputs = alias_inputs.view(-1, alias_inputs.size(1), 1).expand(
            -1, -1, self.hidden_size
        )
        seq_hidden = torch.gather(hidden, dim=1, index=alias_inputs)
        # fetch the last hidden state of last timestamp
        ht = self.gather_indexes(seq_hidden, item_seq_len - 1)

        attention_mask = self.get_attention_mask(item_seq)
        outputs = self.self_attention(
            seq_hidden, attention_mask, output_all_encoded_layers=True
        )
        output = outputs[-1]
        at = self.gather_indexes(output, item_seq_len - 1)
        seq_output = self.weight * at + (1 - self.weight) * ht
        return seq_output

    def calculate_loss(self, interaction):
        """Compute the training loss for one batch.

        BPR over (positive, negative) item scores, or cross-entropy over all items,
        plus ``reg_weight`` times ``EmbLoss`` on the full item embedding table.
        """
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == "BPR":
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
        else:  # self.loss_type = 'CE'
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)
        reg_loss = self.reg_loss(self.item_embedding.weight)
        total_loss = loss + self.reg_weight * reg_loss
        return total_loss

    def predict(self, interaction):
        """Score each (session, candidate item) pair by a dot product; returns [B]."""
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        test_item = interaction[self.ITEM_ID]
        seq_output = self.forward(item_seq, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        """Score every session against every item embedding; returns [B, n_items]."""
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(
            seq_output, test_items_emb.transpose(0, 1)
        )  # [B, n_items]
        return scores