FEA: add XSimGCL #72

Merged · 4 commits · Oct 18, 2023
4 changes: 2 additions & 2 deletions .github/workflows/python-package.yml
@@ -35,9 +35,9 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pytest
pip install dgl
pip install dgl==0.9.1
pip install torch==${{ matrix.torch-version}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
pip install torch-scatter==2.0.9 torch-sparse==0.6.15 torch-cluster==1.6.0 torch-spline-conv==1.2.1 torch-geometric -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
pip install recbole==1.1.1
conda install -c conda-forge faiss-cpu
# Use "python -m pytest" instead of "pytest" to fix imports
1 change: 1 addition & 0 deletions README.md
@@ -57,6 +57,7 @@ We list currently supported models according to category:
* **[HMLET](recbole_gnn/model/general_recommender/hmlet.py)** from Kong *et al.*: [Linear, or Non-Linear, That is the Question!](https://arxiv.org/abs/2111.07265) (WSDM 2022).
* **[NCL](recbole_gnn/model/general_recommender/ncl.py)** from Lin *et al.*: [Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning](https://arxiv.org/abs/2202.06200) (TheWebConf 2022).
* **[SimGCL](recbole_gnn/model/general_recommender/simgcl.py)** from Yu *et al.*: [Are Graph Augmentations Necessary? Simple Graph Contrastive Learning for Recommendation](https://arxiv.org/abs/2112.08679) (SIGIR 2022).
* **[XSimGCL](recbole_gnn/model/general_recommender/xsimgcl.py)** from Yu *et al.*: [XSimGCL: Towards Extremely Simple Graph Contrastive Learning for Recommendation](https://arxiv.org/abs/2209.02544) (TKDE 2023).

**Sequential Recommendation**:

8 changes: 5 additions & 3 deletions recbole_gnn/config.py
@@ -1,4 +1,5 @@
import os
import recbole
from recbole.config.configurator import Config as RecBole_Config
from recbole.utils import ModelType as RecBoleModelType

@@ -16,7 +17,8 @@ def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=
config_file_list (list of str): the external config file, it allows multiple config files, default is None.
config_dict (dict): the external parameter dictionaries, default is None.
"""
self.compatibility_settings()
if recbole.__version__ == "1.1.1":
self.compatibility_settings()
super(Config, self).__init__(model, dataset, config_file_list, config_dict)

def compatibility_settings(self):
@@ -59,7 +61,7 @@ def _get_model_and_dataset(self, model, dataset):
final_dataset = dataset

return final_model, final_model_class, final_dataset

def _load_internal_config_dict(self, model, model_class, dataset):
super()._load_internal_config_dict(model, model_class, dataset)
current_path = os.path.dirname(os.path.realpath(__file__))
@@ -75,4 +77,4 @@ def _load_internal_config_dict(self, model, model_class, dataset):
if self.internal_config_dict['MODEL_TYPE'] == RecBoleModelType.SEQUENTIAL:
self._update_internal_config_dict(sequential_base_init)
if self.internal_config_dict['MODEL_TYPE'] == ModelType.SOCIAL:
self._update_internal_config_dict(social_base_init)
self._update_internal_config_dict(social_base_init)
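
The guard added above only applies `compatibility_settings()` when the installed RecBole is exactly 1.1.1. A minimal sketch of a looser gate (an illustration, not part of this PR) that would also cover other 1.1.x releases, assuming only that `recbole` exposes `__version__`:

```python
import recbole
from packaging import version  # third-party "packaging" library


def needs_compatibility_settings() -> bool:
    # True for any RecBole release older than 1.2.0, not just the exact "1.1.1" string.
    return version.parse(recbole.__version__) < version.parse("1.2.0")
```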
2 changes: 2 additions & 0 deletions recbole_gnn/model/general_recommender/__init__.py
@@ -3,3 +3,5 @@
from recbole_gnn.model.general_recommender.ncl import NCL
from recbole_gnn.model.general_recommender.ngcf import NGCF
from recbole_gnn.model.general_recommender.sgl import SGL
from recbole_gnn.model.general_recommender.simgcl import SimGCL
from recbole_gnn.model.general_recommender.xsimgcl import XSimGCL
90 changes: 90 additions & 0 deletions recbole_gnn/model/general_recommender/xsimgcl.py
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
r"""
XSimGCL
################################################
Reference:
    Junliang Yu, Xin Xia, Tong Chen, Lizhen Cui, Nguyen Quoc Viet Hung, Hongzhi Yin. "XSimGCL: Towards Extremely Simple Graph Contrastive Learning for Recommendation" in TKDE 2023.

Reference code:
    https://github.com/Coder-Yu/SELFRec/blob/main/model/graph/XSimGCL.py
"""


import torch
import torch.nn.functional as F

from recbole_gnn.model.general_recommender import LightGCN


class XSimGCL(LightGCN):
    def __init__(self, config, dataset):
        super(XSimGCL, self).__init__(config, dataset)

        self.cl_rate = config['lambda']
        self.eps = config['eps']
        self.temperature = config['temperature']
        self.layer_cl = config['layer_cl']

    def forward(self, perturbed=False):
        all_embs = self.get_ego_embeddings()
        all_embs_cl = all_embs
        embeddings_list = []

        for layer_idx in range(self.n_layers):
            all_embs = self.gcn_conv(all_embs, self.edge_index, self.edge_weight)
            if perturbed:
                # XSimGCL perturbation: add small noise along the sign of each embedding
                random_noise = torch.rand_like(all_embs, device=all_embs.device)
                all_embs = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * self.eps
            embeddings_list.append(all_embs)
            if layer_idx == self.layer_cl - 1:
                # keep the embeddings of layer `layer_cl` as the second (contrastive) view
                all_embs_cl = all_embs
        lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
        lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
        user_all_embeddings_cl, item_all_embeddings_cl = torch.split(all_embs_cl, [self.n_users, self.n_items])
        if perturbed:
            return user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl
        return user_all_embeddings, item_all_embeddings

    def calculate_cl_loss(self, x1, x2):
        # InfoNCE between the two views: matching rows are positives,
        # all other rows in the batch act as negatives
        x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
        pos_score = (x1 * x2).sum(dim=-1)
        pos_score = torch.exp(pos_score / self.temperature)
        ttl_score = torch.matmul(x1, x2.transpose(0, 1))
        ttl_score = torch.exp(ttl_score / self.temperature).sum(dim=1)
        return -torch.log(pos_score / ttl_score).mean()

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl = self.forward(perturbed=True)
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # calculate regularization Loss
        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)
        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)

        user = torch.unique(interaction[self.USER_ID])
        pos_item = torch.unique(interaction[self.ITEM_ID])

        # calculate CL Loss
        user_cl_loss = self.calculate_cl_loss(user_all_embeddings[user], user_all_embeddings_cl[user])
        item_cl_loss = self.calculate_cl_loss(item_all_embeddings[pos_item], item_all_embeddings_cl[pos_item])

        return mf_loss, self.reg_weight * reg_loss, self.cl_rate * (user_cl_loss + item_cl_loss)
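
The cross-layer contrastive term above is a standard InfoNCE between the final aggregated embeddings and the layer-`layer_cl` embeddings: matching nodes form positives and all other nodes in the mini-batch act as negatives. A self-contained sketch of the same computation on toy tensors (the toy sizes and the 0.2 temperature are illustrative, not tied to RecBole):

```python
import torch
import torch.nn.functional as F


def info_nce(x1: torch.Tensor, x2: torch.Tensor, temperature: float = 0.2) -> torch.Tensor:
    """Same computation as XSimGCL.calculate_cl_loss, written standalone."""
    x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
    pos_score = torch.exp((x1 * x2).sum(dim=-1) / temperature)   # positives: matching rows
    ttl_score = torch.exp(x1 @ x2.t() / temperature).sum(dim=1)  # denominators: all rows
    return -torch.log(pos_score / ttl_score).mean()


# Toy check: two "views" of 5 node embeddings of size 8.
view_a = torch.randn(5, 8)
view_b = view_a + 0.1 * torch.randn(5, 8)          # lightly perturbed copy, mimicking the eps-noise view
print(info_nce(view_a, view_b).item())              # small loss: matching rows are similar
print(info_nce(view_a, torch.randn(5, 8)).item())   # typically larger loss for unrelated views
```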
9 changes: 9 additions & 0 deletions recbole_gnn/properties/model/XSimGCL.yaml
@@ -0,0 +1,9 @@
embedding_size: 64
n_layers: 2
reg_weight: 0.0001

lambda: 0.1
eps: 0.2
temperature: 0.2
layer_cl: 1
require_pow: True
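
With these defaults in place, XSimGCL can be trained like the other general recommenders in the repository. A minimal sketch, assuming the package's `quick_start.run_recbole_gnn` helper (the RecBole-GNN counterpart of RecBole's `run_recbole`) and using `ml-100k` purely as an illustrative dataset:

```python
from recbole_gnn.quick_start import run_recbole_gnn

# Overrides mirror the defaults in XSimGCL.yaml; any of them can be tuned per dataset.
run_recbole_gnn(
    model='XSimGCL',
    dataset='ml-100k',
    config_dict={'lambda': 0.1, 'eps': 0.2, 'temperature': 0.2, 'layer_cl': 1},
)
```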
22 changes: 12 additions & 10 deletions results/general/ml-1m.md
@@ -46,16 +46,17 @@ embedding_size: 64

# Evaluation Results

| Method | Recall@10 | MRR@10 | NDCG@10 | Hit@10 | Precision@10 |
| -------------------- | --------- | ------ | ------- | ------ | ------------ |
| **BPR** | 0.1776 | 0.4187 | 0.2401 | 0.7199 | 0.1779 |
| **NeuMF** | 0.1651 | 0.4020 | 0.2271 | 0.7029 | 0.1700 |
| **NGCF** | 0.1814 | 0.4354 | 0.2508 | 0.7239 | 0.1850 |
| **LightGCN** | 0.1861 | 0.4388 | 0.2538 | 0.7330 | 0.1863 |
| **SGL** | 0.1889 | 0.4315 | 0.2505 | 0.7392 | 0.1843 |
| **HMLET** | 0.1847 | 0.4297 | 0.2490 | 0.7305 | 0.1836 |
| **NCL** | 0.2021 | 0.4599 | 0.2702 | 0.7565 | 0.1962 |
| **SimGCL** | 0.2029 | 0.4550 | 0.2667 | 0.7640 | 0.1933 |
| Method | Recall@10 | MRR@10 | NDCG@10 | Hit@10 | Precision@10 |
| ------------ | --------- | ------ | ------- | ------ | ------------ |
| **BPR** | 0.1776 | 0.4187 | 0.2401 | 0.7199 | 0.1779 |
| **NeuMF** | 0.1651 | 0.4020 | 0.2271 | 0.7029 | 0.1700 |
| **NGCF** | 0.1814 | 0.4354 | 0.2508 | 0.7239 | 0.1850 |
| **LightGCN** | 0.1861 | 0.4388 | 0.2538 | 0.7330 | 0.1863 |
| **SGL** | 0.1889 | 0.4315 | 0.2505 | 0.7392 | 0.1843 |
| **HMLET** | 0.1847 | 0.4297 | 0.2490 | 0.7305 | 0.1836 |
| **NCL** | 0.2021 | 0.4599 | 0.2702 | 0.7565 | 0.1962 |
| **SimGCL** | 0.2029 | 0.4550 | 0.2667 | 0.7640 | 0.1933 |
| **XSimGCL** | 0.2116 | 0.4638 | 0.2750 | 0.7743 | 0.1987 |

# Hyper-parameters

@@ -69,3 +70,4 @@ embedding_size: 64
| **HMLET** | learning_rate=0.002<br />n_layers=4<br />activation_function=leakyrelu | learning_rate choice [0.002, 0.001, 0.0005]<br/>n_layers choice [3, 4]<br/>activation_function choice ['elu', 'leakyrelu'] |
| **NCL** | learning_rate=0.002<br />n_layers=3<br />reg_weight=0.0001<br />ssl_temp=0.1<br />ssl_reg=1e-06<br />hyper_layers=1<br />alpha=1.5 | learning_rate choice [0.002]<br/>n_layers choice [3]<br/>reg_weight choice [1e-4]<br/>ssl_temp choice [0.1, 0.05]<br/>ssl_reg choice [1e-7, 1e-6]<br/>hyper_layers choice [1]<br/>alpha choice [1, 0.8, 1.5] |
| **SimGCL** | learning_rate=0.002<br />n_layers=2<br />reg_weight=0.0001<br />temperature=0.05<br />lambda=1e-5<br />eps=0.1 | learning_rate choice [0.002]<br/>n_layers choice [2, 3]<br/>reg_weight choice [1e-4]<br/>temperature choice [0.05, 0.1, 0.2]<br/>lambda choice [1e-5, 1e-6, 1e-7, 0.005, 0.01, 0.05]<br/>eps choice [0.1, 0.2] |
| **XSimGCL** | learning_rate=0.002<br />n_layers=2<br />reg_weight=0.0001<br />temperature=0.2<br />lambda=0.1<br />eps=0.2<br />layer_cl=1 | learning_rate choice [0.002]<br/>n_layers choice [2, 3]<br/>reg_weight choice [1e-4]<br/>temperature choice [0.05, 0.1, 0.2]<br/>lambda choice [1e-5, 1e-6, 1e-7, 1e-4, 0.005, 0.01, 0.05, 0.1]<br/>eps choice [0.1, 0.2]<br/>layer_cl choice [1] |
6 changes: 6 additions & 0 deletions tests/test_model.py
@@ -60,6 +60,12 @@ def test_simgcl(self):
            'model': 'SimGCL'
        }
        quick_test(config_dict)

    def test_xsimgcl(self):
        config_dict = {
            'model': 'XSimGCL'
        }
        quick_test(config_dict)


class TestSequentialRecommender(unittest.TestCase):
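
The new test runs XSimGCL end-to-end through `quick_test`. Assuming the repository layout above, it can be selected on its own via pytest's programmatic interface (equivalent to the CI's `python -m pytest` invocation):

```python
import pytest

# Run only the XSimGCL test from tests/test_model.py; "-q" keeps the output short.
pytest.main(["tests/test_model.py", "-k", "xsimgcl", "-q"])
```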