## Baseline: Graph Convolutional Matrix Completion (GCMC)
## Rianne van den Berg, Thomas N. Kipf, and Max Welling. Graph Convolutional Matrix Completion. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, KDD '18, 2018.
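## Implementation note: the encoder below stacks the user and item embedding
## tables, propagates them through `layer` sparse graph convolutions
## (sigmoid(A_hat @ E @ W_l)), and concatenates the outputs of all layers.
## Unlike the rating-level softmax decoder of the original GCMC paper, this
## adaptation scores items with inner products and trains with a BPR pairwise
## loss over sampled (user, positive item, negative item) triples.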
import tensorflow as tf
import numpy as np

class model_GCMC(object):
    def __init__(self, layer, n_users, n_items, emb_dim, lr, lamda,
                 pre_train_latent_factor, if_pretrain, sparse_graph):
        self.model_name = 'GCMC'
        self.n_users = n_users
        self.n_items = n_items
        self.emb_dim = emb_dim
        self.layer = layer
        self.lamda = lamda
        self.lr = lr
        [self.U, self.V] = pre_train_latent_factor
        self.if_pretrain = if_pretrain
        self.A_hat = sparse_graph  # sparse propagation matrix over the user-item graph

        # placeholder definitions
        self.users = tf.placeholder(tf.int32, shape=(None,))
        self.pos_items = tf.placeholder(tf.int32, shape=(None,))
        self.neg_items = tf.placeholder(tf.int32, shape=(None,))
        self.keep_prob = tf.placeholder(tf.float32, shape=None)  # defined for interface parity; unused below
        self.items_in_train_data = tf.placeholder(tf.float32, shape=(None, None))
        self.top_k = tf.placeholder(tf.int32, shape=None)
        if self.if_pretrain:
            # warm-start from the given pre-trained latent factors
            self.user_embeddings = tf.Variable(self.U, name='user_embeddings')
            self.item_embeddings = tf.Variable(self.V, name='item_embeddings')
        else:
            self.user_embeddings = tf.Variable(
                tf.random_normal([self.n_users, self.emb_dim], mean=0.01, stddev=0.02, dtype=tf.float32),
                name='user_embeddings')
            self.item_embeddings = tf.Variable(
                tf.random_normal([self.n_items, self.emb_dim], mean=0.01, stddev=0.02, dtype=tf.float32),
                name='item_embeddings')

        # one dense filter per propagation layer, initialized near the identity
        # so the first forward pass roughly preserves the input embeddings
        self.filters = []
        for l in range(self.layer):
            self.filters.append(tf.Variable(
                (np.random.normal(0, 0.001, (self.emb_dim, self.emb_dim)) +
                 np.diag(np.random.normal(1, 0.001, self.emb_dim))).astype(np.float32)))
        # graph convolution: propagate the stacked user/item embeddings through
        # A_hat and concatenate the representations from every layer
        embeddings = tf.concat([self.user_embeddings, self.item_embeddings], axis=0)
        all_embeddings = [embeddings]
        for l in range(self.layer):
            embeddings = tf.sparse_tensor_dense_matmul(self.A_hat, embeddings)
            embeddings = tf.nn.sigmoid(tf.matmul(embeddings, self.filters[l]))
            all_embeddings += [embeddings]
        all_embeddings = tf.concat(all_embeddings, 1)
        self.user_all_embeddings, self.item_all_embeddings = tf.split(
            all_embeddings, [self.n_users, self.n_items], 0)

        # batch lookups on the propagated embeddings (used for scoring) ...
        self.u_embeddings = tf.nn.embedding_lookup(self.user_all_embeddings, self.users)
        self.pos_i_embeddings = tf.nn.embedding_lookup(self.item_all_embeddings, self.pos_items)
        self.neg_i_embeddings = tf.nn.embedding_lookup(self.item_all_embeddings, self.neg_items)
        # ... and on the raw layer-0 embeddings, which are the ones regularized
        self.u_embeddings_reg = tf.nn.embedding_lookup(self.user_embeddings, self.users)
        self.pos_i_embeddings_reg = tf.nn.embedding_lookup(self.item_embeddings, self.pos_items)
        self.neg_i_embeddings_reg = tf.nn.embedding_lookup(self.item_embeddings, self.neg_items)

        self.loss = self.create_bpr_loss(self.u_embeddings, self.pos_i_embeddings, self.neg_i_embeddings) + \
                    self.lamda * self.regularization(self.u_embeddings_reg, self.pos_i_embeddings_reg,
                                                     self.neg_i_embeddings_reg, self.filters)
        self.opt = tf.train.AdamOptimizer(learning_rate=self.lr)
        self.updates = self.opt.minimize(
            self.loss, var_list=[self.user_embeddings, self.item_embeddings] + self.filters)

        # full score matrix for evaluation; items_in_train_data is expected to
        # hold a very small (large negative) value at training interactions so
        # that already-seen items sink to the end of the ranked list
        self.all_ratings = tf.matmul(self.u_embeddings, self.item_all_embeddings, transpose_b=True)
        self.all_ratings += self.items_in_train_data
        self.top_items = tf.nn.top_k(self.all_ratings, k=self.top_k, sorted=True).indices
    def create_bpr_loss(self, users, pos_items, neg_items):
        # BPR pairwise ranking loss: -sum(log(sigmoid(pos_score - neg_score)))
        pos_scores = tf.reduce_sum(tf.multiply(users, pos_items), axis=1)
        neg_scores = tf.reduce_sum(tf.multiply(users, neg_items), axis=1)
        maxi = tf.log(tf.nn.sigmoid(pos_scores - neg_scores))
        loss = tf.negative(tf.reduce_sum(maxi))
        return loss
    def regularization(self, users, pos_items, neg_items, filters):
        # L2 penalty on the batch's layer-0 embeddings and on every layer filter
        regularizer = tf.nn.l2_loss(users) + tf.nn.l2_loss(pos_items) + tf.nn.l2_loss(neg_items)
        for l in range(self.layer):
            regularizer += tf.nn.l2_loss(filters[l])
        return regularizer
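

## Minimal usage sketch (not part of the original file). It assumes TF 1.x and
## builds a toy symmetrically normalized adjacency
## A_hat = D^{-1/2} (A + I) D^{-1/2} over the (n_users + n_items)-node
## user-item graph; `build_sparse_graph` and the toy interaction matrix `R`
## are illustrative helpers, not part of the class above.
if __name__ == '__main__':
    import scipy.sparse as sp

    def build_sparse_graph(R):
        """Toy normalized adjacency of the bipartite user-item graph."""
        n_u, n_i = R.shape
        A = sp.bmat([[None, R], [R.T, None]])  # block bipartite adjacency
        A = (A + sp.eye(n_u + n_i)).tocsr()    # self-loops keep all degrees positive
        d_inv_sqrt = np.asarray(A.sum(axis=1)).flatten() ** -0.5
        A_hat = sp.diags(d_inv_sqrt) @ A @ sp.diags(d_inv_sqrt)
        coo = A_hat.tocoo().astype(np.float32)
        indices = np.vstack([coo.row, coo.col]).T.astype(np.int64)
        return tf.sparse_reorder(tf.SparseTensor(indices, coo.data, coo.shape))

    n_users, n_items = 4, 5
    R = sp.csr_matrix(np.random.randint(0, 2, size=(n_users, n_items)).astype(np.float32))
    model = model_GCMC(layer=2, n_users=n_users, n_items=n_items, emb_dim=8,
                       lr=0.001, lamda=0.02,
                       pre_train_latent_factor=[None, None], if_pretrain=False,
                       sparse_graph=build_sparse_graph(R))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # one BPR update on two sampled (user, pos, neg) triples
        _, loss = sess.run([model.updates, model.loss],
                           feed_dict={model.users: [0, 1],
                                      model.pos_items: [0, 2],
                                      model.neg_items: [3, 4]})
        print('one BPR step, loss =', loss)
        # ranking: mask training interactions with a large negative value
        mask = np.where(R.toarray()[[0, 1]] > 0, -1e8, 0.0).astype(np.float32)
        top = sess.run(model.top_items,
                       feed_dict={model.users: [0, 1],
                                  model.items_in_train_data: mask,
                                  model.top_k: 3})
        print('top-3 items per user:', top)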