# MHCN.py
from base.graphRecommender import GraphRecommender
from base.socialRecommender import SocialRecommender
import tensorflow as tf
from scipy.sparse import coo_matrix
import numpy as np
from util.loss import bpr_loss
import os
from util import config
from math import sqrt
# Set TensorFlow's native (C++) log-level environment variable to "3" to
# quiet its console output.
# NOTE(review): this assignment runs after `import tensorflow` above --
# confirm it still takes effect for the TensorFlow version in use.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Recommended Maximum Epoch Setting: LastFM 120 Douban 30 Yelp 30
# A slight performance drop was observed after porting the model from Python 2 to Python 3; the cause is unclear.
class MHCN(SocialRecommender,GraphRecommender):
    """Social recommender that convolves user embeddings over motif-induced graphs.

    Written against the TensorFlow 1.x graph API (placeholders, sessions,
    ``tf.contrib``). User embeddings are propagated over three user-user
    channels (``H_s``, ``H_j``, ``H_p``; see
    ``buildMotifInducedAdjacencyMatrix``) and over the normalized user-item
    graph, mixed by a learned channel attention, and trained with a BPR
    ranking loss, L2 regularization and a self-supervised loss weighted by
    the ``-ss_rate`` option.

    NOTE(review): ``self.data``, ``self.social``, ``self.sess``,
    ``self.user_embeddings``, ``self.item_embeddings``, ``self.u_idx``,
    ``self.v_idx``, ``self.emb_size``, ``self.num_users``, ``self.num_items``,
    ``self.regU``, ``self.lRate``, ``self.maxEpoch`` and ``self.foldInfo`` are
    not defined in this class; presumably the parent recommenders provide
    them -- confirm against the base classes.
    """

    def __init__(self, conf, trainingSet=None, testSet=None, relation=None, fold='[1]'):
        """Initialise both parent recommenders explicitly.

        Each base ``__init__`` is called by name (rather than through
        ``super()``) because only ``SocialRecommender`` accepts ``relation``.
        """
        GraphRecommender.__init__(self, conf=conf, trainingSet=trainingSet, testSet=testSet, fold=fold)
        SocialRecommender.__init__(self, conf=conf, trainingSet=trainingSet, testSet=testSet, relation=relation, fold=fold)

    def readConfiguration(self):
        """Read the ``MHCN`` section of the configuration.

        Sets ``self.n_layers`` (option ``-n_layer``) and ``self.ss_rate``
        (option ``-ss_rate``) after the parents have read their own options.
        """
        super(MHCN, self).readConfiguration()
        args = config.OptionConf(self.config['MHCN'])
        self.n_layers = int(args['-n_layer'])
        self.ss_rate = float(args['-ss_rate'])

    def buildSparseRelationMatrix(self):
        """Return the user-user social matrix as a float32 ``coo_matrix``.

        One entry of 1.0 is emitted per pair in ``self.social.relation``, at
        (row = first user, column = second user). The pairs are stored as
        given: no reverse edge is added here, so the result is symmetric only
        if the relation data itself is.
        """
        row, col = [], []
        for pair in self.social.relation:
            row.append(self.data.user[pair[0]])
            col.append(self.data.user[pair[1]])
        entries = [1.0] * len(row)
        return coo_matrix((entries, (row, col)), shape=(self.num_users, self.num_users), dtype=np.float32)

    def buildSparseRatingMatrix(self):
        """Return the binary user-item interaction matrix as a float32 ``coo_matrix``.

        One entry of 1.0 is emitted per record in ``self.data.trainingData``;
        the rating value itself is ignored. Shape is (num_users, num_items).
        """
        row, col = [], []
        for pair in self.data.trainingData:
            row.append(self.data.user[pair[0]])
            col.append(self.data.item[pair[1]])
        entries = [1.0] * len(row)
        return coo_matrix((entries, (row, col)), shape=(self.num_users, self.num_items), dtype=np.float32)

    def buildJointAdjacency(self):
        """Return the normalized user-item graph as a ``tf.SparseTensor``.

        Each training record (user, item, rating) contributes
        ``rating / sqrt(|trainSet_u[user]|) / sqrt(|trainSet_i[item]|)`` at
        position (user index, item index). Dense shape is
        (num_users, num_items).
        """
        records = self.data.trainingData
        indices = [[self.data.user[rec[0]], self.data.item[rec[1]]] for rec in records]
        values = [float(rec[2]) / sqrt(len(self.data.trainSet_u[rec[0]])) / sqrt(len(self.data.trainSet_i[rec[1]]))
                  for rec in records]
        return tf.SparseTensor(indices=indices, values=values,
                               dense_shape=[self.num_users, self.num_items])

    def buildMotifInducedAdjacencyMatrix(self):
        """Build the three row-normalized user-user channel matrices.

        Also caches the interaction matrix in CSR form as
        ``self.userAdjacency`` (users x items) and ``self.itemAdjacency``
        (items x users).

        Returns:
            ``[H_s, H_j, H_p]`` as scipy sparse matrices:
            ``H_s`` sums the social-only motifs A1..A7,
            ``H_j`` sums the social + co-interaction motifs A8 and A9,
            ``H_p`` holds co-interaction pairs with no social edge (A10),
            keeping only entries greater than 1.
        """
        def row_normalize(mat):
            # Scale every row to sum to 1. Rows whose sum is 0 (users with no
            # neighbours in this channel) are left as zeros instead of being
            # scaled by 1/0. The scale vector keeps the matrix's float dtype.
            row_sums = np.asarray(mat.sum(axis=1)).reshape(-1, 1)
            inverse = np.zeros_like(row_sums)
            np.divide(1.0, row_sums, out=inverse, where=row_sums != 0)
            return mat.multiply(inverse)

        S = self.buildSparseRelationMatrix()
        Y = self.buildSparseRatingMatrix()
        self.userAdjacency = Y.tocsr()
        self.itemAdjacency = Y.T.tocsr()
        # B: edges present in both directions; U: the remaining one-way edges.
        B = S.multiply(S.T)
        U = S - B
        # Social motifs: products of B/U count two-step paths, and the
        # element-wise multiply keeps only pairs closed by a third edge.
        C1 = (U.dot(U)).multiply(U.T)
        A1 = C1 + C1.T
        C2 = (B.dot(U)).multiply(U.T) + (U.dot(B)).multiply(U.T) + (U.dot(U)).multiply(B)
        A2 = C2 + C2.T
        C3 = (B.dot(B)).multiply(U) + (B.dot(U)).multiply(B) + (U.dot(B)).multiply(B)
        A3 = C3 + C3.T
        A4 = (B.dot(B)).multiply(B)
        C5 = (U.dot(U)).multiply(U) + (U.dot(U.T)).multiply(U) + (U.T.dot(U)).multiply(U)
        A5 = C5 + C5.T
        A6 = (U.dot(B)).multiply(U) + (B.dot(U.T)).multiply(U.T) + (U.T.dot(U)).multiply(B)
        A7 = (U.T.dot(B)).multiply(U.T) + (B.dot(U)).multiply(U) + (U.dot(U.T)).multiply(B)
        # Y.Y^T counts items shared by two users. A8/A9 restrict it to
        # socially connected pairs (mutual / one-way), A10 is the remainder.
        A8 = (Y.dot(Y.T)).multiply(B)
        A9 = (Y.dot(Y.T)).multiply(U)
        A9 = A9+A9.T
        A10 = Y.dot(Y.T)-A8-A9
        #addition and row-normalization
        H_s = row_normalize(sum([A1,A2,A3,A4,A5,A6,A7]))
        H_j = row_normalize(sum([A8,A9]))
        H_p = A10
        # keep only pairs that share more than one item
        H_p = H_p.multiply(H_p>1)
        H_p = row_normalize(H_p)
        return [H_s,H_j,H_p]

    def adj_to_sparse_tensor(self,adj):
        """Convert a scipy sparse matrix into a float32 ``tf.SparseTensor``.

        Indices are passed as an (nnz, 2) int64 array of (row, col) pairs,
        which stays two-dimensional even when the matrix has no stored entry.
        """
        adj = adj.tocoo()
        indices = np.column_stack((adj.row, adj.col)).astype(np.int64)
        return tf.SparseTensor(indices, adj.data.astype(np.float32), adj.shape)

    def initModel(self):
        """Build the TensorFlow graph of the model.

        Creates the gating and attention parameters in ``self.weights``, the
        ``self.neg_idx`` placeholder, the propagated
        ``self.final_user_embeddings`` / ``self.final_item_embeddings``, the
        self-supervised loss ``self.ss_loss`` and the per-batch embedding
        look-ups consumed by ``trainModel``.
        """
        super(MHCN, self).initModel()
        M_matrices = self.buildMotifInducedAdjacencyMatrix()
        self.weights = {}
        initializer = tf.contrib.layers.xavier_initializer()
        self.n_channel = 4
        self.neg_idx = tf.placeholder(tf.int32, name="neg_holder")
        #define learnable parameters
        for channel in range(1, self.n_channel + 1):
            self.weights['gating%d' % channel] = tf.Variable(initializer([self.emb_size, self.emb_size]), name='g_W_%d_1' % channel)
            self.weights['gating_bias%d' % channel] = tf.Variable(initializer([1, self.emb_size]), name='g_W_b_%d_1' % channel)
            self.weights['sgating%d' % channel] = tf.Variable(initializer([self.emb_size, self.emb_size]), name='sg_W_%d_1' % channel)
            self.weights['sgating_bias%d' % channel] = tf.Variable(initializer([1, self.emb_size]), name='sg_W_b_%d_1' % channel)
        self.weights['attention'] = tf.Variable(initializer([1, self.emb_size]), name='at')
        self.weights['attention_mat'] = tf.Variable(initializer([self.emb_size, self.emb_size]), name='atm')

        #define inline functions
        def self_gating(em,channel):
            # element-wise gate: em * sigmoid(em . W_channel + b_channel)
            return tf.multiply(em,tf.nn.sigmoid(tf.matmul(em,self.weights['gating%d' % channel])+self.weights['gating_bias%d' %channel]))

        def self_supervised_gating(em, channel):
            # same gate form, with the separate 'sgating' parameters
            return tf.multiply(em,tf.nn.sigmoid(tf.matmul(em, self.weights['sgating%d' % channel])+self.weights['sgating_bias%d' % channel]))

        def channel_attention(*channel_embeddings):
            # One scalar score per row and channel, softmax-ed across the
            # channels; returns the score-weighted sum and the scores.
            weights = []
            for embedding in channel_embeddings:
                weights.append(tf.reduce_sum(tf.multiply(self.weights['attention'], tf.matmul(embedding, self.weights['attention_mat'])),1))
            score = tf.nn.softmax(tf.transpose(weights))
            mixed_embeddings = 0
            for i, embedding in enumerate(channel_embeddings):
                mixed_embeddings += tf.transpose(tf.multiply(tf.transpose(score)[i], tf.transpose(embedding)))
            return mixed_embeddings,score

        #initialize adjacency matrices
        H_s = self.adj_to_sparse_tensor(M_matrices[0])
        H_j = self.adj_to_sparse_tensor(M_matrices[1])
        H_p = self.adj_to_sparse_tensor(M_matrices[2])
        R = self.buildJointAdjacency()
        #self-gating
        user_embeddings_c1 = self_gating(self.user_embeddings,1)
        user_embeddings_c2 = self_gating(self.user_embeddings, 2)
        user_embeddings_c3 = self_gating(self.user_embeddings, 3)
        simple_user_embeddings = self_gating(self.user_embeddings,4)
        all_embeddings_c1 = [user_embeddings_c1]
        all_embeddings_c2 = [user_embeddings_c2]
        all_embeddings_c3 = [user_embeddings_c3]
        all_embeddings_simple = [simple_user_embeddings]
        item_embeddings = self.item_embeddings
        all_embeddings_i = [item_embeddings]
        self.ss_loss = 0
        #multi-channel convolution
        for _ in range(self.n_layers):
            # The mix is taken before this layer's propagation, so items are
            # updated from the previous layer's user embeddings.
            mixed_embedding = channel_attention(user_embeddings_c1, user_embeddings_c2, user_embeddings_c3)[0] + simple_user_embeddings / 2
            #Channel S
            user_embeddings_c1 = tf.sparse_tensor_dense_matmul(H_s,user_embeddings_c1)
            norm_embeddings = tf.math.l2_normalize(user_embeddings_c1, axis=1)
            all_embeddings_c1 += [norm_embeddings]
            #Channel J
            user_embeddings_c2 = tf.sparse_tensor_dense_matmul(H_j, user_embeddings_c2)
            norm_embeddings = tf.math.l2_normalize(user_embeddings_c2, axis=1)
            all_embeddings_c2 += [norm_embeddings]
            #Channel P
            user_embeddings_c3 = tf.sparse_tensor_dense_matmul(H_p, user_embeddings_c3)
            norm_embeddings = tf.math.l2_normalize(user_embeddings_c3, axis=1)
            all_embeddings_c3 += [norm_embeddings]
            # item convolution
            new_item_embeddings = tf.sparse_tensor_dense_matmul(tf.sparse.transpose(R), mixed_embedding)
            norm_embeddings = tf.math.l2_normalize(new_item_embeddings, axis=1)
            all_embeddings_i += [norm_embeddings]
            # user update from the items of the previous layer
            simple_user_embeddings = tf.sparse_tensor_dense_matmul(R, item_embeddings)
            all_embeddings_simple += [tf.math.l2_normalize(simple_user_embeddings, axis=1)]
            item_embeddings = new_item_embeddings
        #summing the per-layer outputs of every channel
        user_embeddings_c1 = tf.reduce_sum(all_embeddings_c1, axis=0)
        user_embeddings_c2 = tf.reduce_sum(all_embeddings_c2, axis=0)
        user_embeddings_c3 = tf.reduce_sum(all_embeddings_c3, axis=0)
        simple_user_embeddings = tf.reduce_sum(all_embeddings_simple, axis=0)
        item_embeddings = tf.reduce_sum(all_embeddings_i, axis=0)
        #aggregating channel-specific embeddings
        self.final_item_embeddings = item_embeddings
        self.final_user_embeddings,self.attention_score = channel_attention(user_embeddings_c1,user_embeddings_c2,user_embeddings_c3)
        self.final_user_embeddings += simple_user_embeddings/2
        #create self-supervised loss
        self.ss_loss += self.hierarchical_self_supervision(self_supervised_gating(self.final_user_embeddings,1), H_s)
        self.ss_loss += self.hierarchical_self_supervision(self_supervised_gating(self.final_user_embeddings,2), H_j)
        self.ss_loss += self.hierarchical_self_supervision(self_supervised_gating(self.final_user_embeddings,3), H_p)
        #embedding look-up
        self.batch_neg_item_emb = tf.nn.embedding_lookup(self.final_item_embeddings, self.neg_idx)
        self.batch_user_emb = tf.nn.embedding_lookup(self.final_user_embeddings, self.u_idx)
        self.batch_pos_item_emb = tf.nn.embedding_lookup(self.final_item_embeddings, self.v_idx)

    def hierarchical_self_supervision(self,em,adj):
        """Return the self-supervised loss of one channel.

        Args:
            em: user embedding tensor (one row per user).
            adj: sparse channel adjacency; ``adj @ em`` gives each user's
                neighbourhood ("edge") embedding.

        Returns:
            Scalar tensor ``global_loss + local_loss``. The local term ranks
            a user's own neighbourhood embedding above two shuffled
            corruptions; the global term ranks true neighbourhood embeddings
            above shuffled ones against the mean neighbourhood embedding.
        """
        def row_shuffle(embedding):
            return tf.gather(embedding, tf.random.shuffle(tf.range(tf.shape(embedding)[0])))

        def row_column_shuffle(embedding):
            # shuffle the columns first (via a transpose), then the rows
            corrupted_embedding = tf.transpose(tf.gather(tf.transpose(embedding), tf.random.shuffle(tf.range(tf.shape(tf.transpose(embedding))[0]))))
            corrupted_embedding = tf.gather(corrupted_embedding, tf.random.shuffle(tf.range(tf.shape(corrupted_embedding)[0])))
            return corrupted_embedding

        def score(x1,x2):
            # row-wise dot product
            return tf.reduce_sum(tf.multiply(x1,x2),1)

        user_embeddings = em
        # user_embeddings = tf.math.l2_normalize(em,1) #For Douban, normalization is needed.
        edge_embeddings = tf.sparse_tensor_dense_matmul(adj,user_embeddings)
        #Local MIM
        pos = score(user_embeddings,edge_embeddings)
        neg1 = score(row_shuffle(user_embeddings),edge_embeddings)
        neg2 = score(row_column_shuffle(edge_embeddings),user_embeddings)
        local_loss = tf.reduce_sum(-tf.log(tf.sigmoid(pos-neg1))-tf.log(tf.sigmoid(neg1-neg2)))
        #Global MIM
        graph = tf.reduce_mean(edge_embeddings,0)
        pos = score(edge_embeddings,graph)
        neg1 = score(row_column_shuffle(edge_embeddings),graph)
        global_loss = tf.reduce_sum(-tf.log(tf.sigmoid(pos-neg1)))
        return global_loss+local_loss

    def trainModel(self):
        """Train the model with Adam and evaluate after every epoch.

        The minimized loss is ``rec_loss + reg_loss + ss_rate * ss_loss``.
        After the last epoch ``self.U`` / ``self.V`` are set to
        ``self.bestU`` / ``self.bestV``, which ``saveModel`` stores.
        NOTE(review): ``saveModel`` is presumably invoked by
        ``ranking_performance`` when a better result is found -- confirm in
        the base class.
        """
        rec_loss = bpr_loss(self.batch_user_emb, self.batch_pos_item_emb, self.batch_neg_item_emb)
        reg_loss = 0
        for key in self.weights:
            reg_loss += 0.001*tf.nn.l2_loss(self.weights[key])
        reg_loss += self.regU * (tf.nn.l2_loss(self.user_embeddings) + tf.nn.l2_loss(self.item_embeddings))
        total_loss = rec_loss+reg_loss + self.ss_rate*self.ss_loss
        opt = tf.train.AdamOptimizer(self.lRate)
        train_op = opt.minimize(total_loss)
        init = tf.global_variables_initializer()
        self.sess.run(init)
        # Suggested Maximum epoch Setting: LastFM 120 Douban 30 Yelp 30
        for epoch in range(self.maxEpoch):
            for n, batch in enumerate(self.next_batch_pairwise()):
                user_idx, i_idx, j_idx = batch
                _, l1 = self.sess.run([train_op, rec_loss],
                                      feed_dict={self.u_idx: user_idx, self.neg_idx: j_idx, self.v_idx: i_idx})
                print(self.foldInfo,'training:', epoch + 1, 'batch', n, 'rec loss:', l1)
            # refresh the dense embeddings used for this epoch's evaluation
            self.U, self.V = self.sess.run([self.final_user_embeddings, self.final_item_embeddings])
            self.ranking_performance(epoch)
        self.U,self.V = self.bestU,self.bestV

    def saveModel(self):
        """Store the current final embeddings as ``self.bestU`` / ``self.bestV``."""
        self.bestU, self.bestV = self.sess.run([self.final_user_embeddings, self.final_item_embeddings])

    def predictForRanking(self, u):
        """Return a score for every item, used to rank items for user ``u``.

        Known users get the product of ``self.V`` with their row of
        ``self.U``; unknown users get a list filled with
        ``self.data.globalMean``.
        """
        if self.data.containsUser(u):
            u = self.data.getUserId(u)
            return self.V.dot(self.U[u])
        else:
            return [self.data.globalMean] * self.num_items