-
Notifications
You must be signed in to change notification settings - Fork 79
/
Copy pathXSimGCL.py
101 lines (89 loc) · 4.92 KB
/
XSimGCL.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
import torch.nn.functional as F
from base.graph_recommender import GraphRecommender
from util.sampler import next_batch_pairwise
from base.torch_interface import TorchGraphInterface
from util.loss_torch import bpr_loss, l2_reg_loss, InfoNCE
# Paper: XSimGCL - Towards Extremely Simple Graph Contrastive Learning for Recommendation
class XSimGCL(GraphRecommender):
    """Training/evaluation driver for the XSimGCL recommender.

    Reads its hyper-parameters from ``self.config['XSimGCL']``:

    * ``lambda`` -> ``cl_rate``  : weight applied to the contrastive loss
    * ``eps``    -> ``eps``      : magnitude of the noise added by the encoder
    * ``tau``    -> ``temp``     : temperature passed to ``InfoNCE``
    * ``n_layer`` -> ``n_layers``: number of propagation layers
    * ``l_star`` -> ``layer_cl`` : 1-based layer used as the contrastive view

    ``self.config``, ``self.data``, ``self.emb_size``, ``self.lRate``,
    ``self.maxEpoch``, ``self.batch_size``, ``self.reg`` and
    ``fast_evaluation`` are not defined here; presumably they are provided by
    ``GraphRecommender`` -- confirm against the base class.
    """

    def __init__(self, conf, training_set, test_set):
        super(XSimGCL, self).__init__(conf, training_set, test_set)
        config = self.config['XSimGCL']
        self.cl_rate = float(config['lambda'])
        self.eps = float(config['eps'])
        self.temp = float(config['tau'])
        self.n_layers = int(config['n_layer'])
        self.layer_cl = int(config['l_star'])
        self.model = XSimGCL_Encoder(self.data, self.emb_size, self.eps, self.n_layers, self.layer_cl)

    def train(self):
        """Optimise the encoder with the joint ranking + contrastive objective.

        For every batch the encoder is run in perturbed mode, and the loss is
        ``bpr_loss + l2_reg_loss + cl_rate * contrastive_loss``. After each
        epoch the un-perturbed embeddings are recomputed without gradients and
        ``fast_evaluation(epoch)`` is invoked. When all epochs are done,
        ``user_emb``/``item_emb`` are replaced by ``best_user_emb``/
        ``best_item_emb`` (NOTE(review): these are only assigned in ``save``;
        presumably ``fast_evaluation`` calls it -- confirm).
        """
        # Use the GPU when one is present (same as the former hard-coded
        # ``.cuda()``), but do not crash on CPU-only hosts.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = self.model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lRate)
        for epoch in range(self.maxEpoch):
            for n, batch in enumerate(next_batch_pairwise(self.data, self.batch_size)):
                user_idx, pos_idx, neg_idx = batch
                rec_user_emb, rec_item_emb, cl_user_emb, cl_item_emb = model(True)
                user_emb = rec_user_emb[user_idx]
                pos_item_emb = rec_item_emb[pos_idx]
                neg_item_emb = rec_item_emb[neg_idx]
                rec_loss = bpr_loss(user_emb, pos_item_emb, neg_item_emb)
                cl_loss = self.cl_rate * self.cal_cl_loss(
                    [user_idx, pos_idx], rec_user_emb, cl_user_emb, rec_item_emb, cl_item_emb)
                batch_loss = rec_loss + l2_reg_loss(self.reg, user_emb, pos_item_emb) + cl_loss
                # Backward and optimize
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
                if n % 100 == 0 and n > 0:
                    print('training:', epoch + 1, 'batch', n, 'rec_loss:', rec_loss.item(), 'cl_loss', cl_loss.item())
            # Evaluation uses clean (noise-free) embeddings.
            with torch.no_grad():
                self.user_emb, self.item_emb = self.model()
            self.fast_evaluation(epoch)
        self.user_emb, self.item_emb = self.best_user_emb, self.best_item_emb

    def cal_cl_loss(self, idx, user_view1, user_view2, item_view1, item_view2):
        """Return the user + item InfoNCE loss between two embedding views.

        Args:
            idx: two-element sequence ``[user_indices, item_indices]`` for the
                current batch. Duplicates are removed before indexing.
            user_view1, user_view2: the two user embedding views to contrast.
            item_view1, item_view2: the two item embedding views to contrast.

        Returns:
            ``InfoNCE(users) + InfoNCE(items)`` evaluated at ``self.temp``.
        """
        # Build integer tensors directly. ``torch.Tensor(...)`` would create a
        # float32 tensor first, which cannot represent ids above 2**24 exactly
        # and would silently select the wrong rows on large datasets.
        # Indices are placed on the same device as the views they index.
        u_idx = torch.unique(torch.as_tensor(idx[0], dtype=torch.long)).to(user_view1.device)
        i_idx = torch.unique(torch.as_tensor(idx[1], dtype=torch.long)).to(item_view1.device)
        user_cl_loss = InfoNCE(user_view1[u_idx], user_view2[u_idx], self.temp)
        item_cl_loss = InfoNCE(item_view1[i_idx], item_view2[i_idx], self.temp)
        return user_cl_loss + item_cl_loss

    def save(self):
        """Snapshot the current noise-free embeddings as the best ones."""
        with torch.no_grad():
            self.best_user_emb, self.best_item_emb = self.model.forward()

    def predict(self, u):
        """Score every item for user ``u``.

        ``u`` is first translated with ``self.data.get_user_id``; the result is
        the product of that user's embedding with all item embeddings,
        returned as a NumPy array on the host.
        """
        u = self.data.get_user_id(u)
        score = torch.matmul(self.user_emb[u], self.item_emb.transpose(0, 1))
        return score.cpu().numpy()
class XSimGCL_Encoder(nn.Module):
    """Graph encoder that propagates user/item embeddings over ``norm_adj``.

    Args:
        data: object exposing ``norm_adj``, ``user_num`` and ``item_num``.
        emb_size: width of each embedding vector.
        eps: scale of the noise added after each layer in perturbed mode.
        n_layers: number of propagation layers. Must be >= 1, since
            ``forward`` stacks one output per layer.
        layer_cl: 1-based index of the layer whose output is returned as the
            contrastive view. If it lies outside ``1..n_layers`` the initial
            (un-propagated) embeddings are returned as that view instead.
    """

    def __init__(self, data, emb_size, eps, n_layers, layer_cl):
        super(XSimGCL_Encoder, self).__init__()
        self.data = data
        self.eps = eps
        self.emb_size = emb_size
        self.n_layers = n_layers
        self.layer_cl = layer_cl
        self.norm_adj = data.norm_adj
        self.embedding_dict = self._init_model()
        # The adjacency is a plain attribute (not a parameter/buffer), so it is
        # not moved by ``Module.to``/``Module.cuda`` and must be placed here.
        # Prefer the GPU as before, but fall back to CPU when none exists.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.sparse_norm_adj = TorchGraphInterface.convert_sparse_mat_to_tensor(self.norm_adj).to(device)

    def _init_model(self):
        """Create Xavier-uniform initialised user and item embedding tables."""
        initializer = nn.init.xavier_uniform_
        embedding_dict = nn.ParameterDict({
            'user_emb': nn.Parameter(initializer(torch.empty(self.data.user_num, self.emb_size))),
            'item_emb': nn.Parameter(initializer(torch.empty(self.data.item_num, self.emb_size))),
        })
        return embedding_dict

    def forward(self, perturbed=False):
        """Propagate embeddings and return the layer-averaged result.

        Args:
            perturbed: when true, after every layer a random vector (uniform
                in [0, 1), L2-normalised per row, scaled by ``eps`` and given
                the sign of the current embedding) is added to the embeddings.

        Returns:
            ``(user_emb, item_emb)`` when ``perturbed`` is false, otherwise
            ``(user_emb, item_emb, user_emb_cl, item_emb_cl)`` where the
            ``*_cl`` tensors are the output of layer ``layer_cl``.
        """
        ego_embeddings = torch.cat([self.embedding_dict['user_emb'], self.embedding_dict['item_emb']], 0)
        all_embeddings = []
        # Fallback contrastive view if ``layer_cl`` never matches a layer.
        all_embeddings_cl = ego_embeddings
        for k in range(self.n_layers):
            ego_embeddings = torch.sparse.mm(self.sparse_norm_adj, ego_embeddings)
            if perturbed:
                # ``rand_like`` already matches the device and dtype of its
                # argument, so no explicit device transfer is needed.
                random_noise = torch.rand_like(ego_embeddings)
                ego_embeddings = ego_embeddings + (
                    torch.sign(ego_embeddings) * F.normalize(random_noise, dim=-1) * self.eps)
            all_embeddings.append(ego_embeddings)
            if k == self.layer_cl - 1:
                all_embeddings_cl = ego_embeddings
        # Average the per-layer outputs (the initial embeddings are excluded).
        final_embeddings = torch.stack(all_embeddings, dim=1)
        final_embeddings = torch.mean(final_embeddings, dim=1)
        user_all_embeddings, item_all_embeddings = torch.split(
            final_embeddings, [self.data.user_num, self.data.item_num])
        user_all_embeddings_cl, item_all_embeddings_cl = torch.split(
            all_embeddings_cl, [self.data.user_num, self.data.item_num])
        if perturbed:
            return user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl
        return user_all_embeddings, item_all_embeddings