FEA: add DirectAU and fix some bugs #74

Merged: 4 commits, Oct 23, 2023
31 changes: 26 additions & 5 deletions recbole_gnn/data/dataset.py
@@ -6,15 +6,33 @@
from tqdm import tqdm
from torch_geometric.utils import degree


from recbole.data.dataset import SequentialDataset
from recbole.data.dataset import Dataset as RecBoleDataset
from recbole.utils import set_color, FeatureSource

import recbole
import pickle
from recbole.utils import ensure_dir


class GeneralGraphDataset(RecBoleDataset):
def __init__(self, config):
super().__init__(config)

if recbole.__version__ == "1.1.1":

def save(self):
"""Saving this :class:`Dataset` object to :attr:`config['checkpoint_dir']`."""
save_dir = self.config["checkpoint_dir"]
ensure_dir(save_dir)
file = os.path.join(save_dir, f'{self.config["dataset"]}-{self.__class__.__name__}.pth')
self.logger.info(
set_color("Saving filtered dataset into ", "pink") + f"[{file}]"
)
with open(file, "wb") as f:
pickle.dump(self, f)

def get_norm_adj_mat(self):
r"""Get the normalized interaction matrix of users and items.
Construct the square matrix from the training data and normalize it
@@ -101,6 +119,7 @@ def build(self):
dataset.session_graph_construction()
return datasets


class MultiBehaviorDataset(SessionGraphDataset):

def session_graph_construction(self):
@@ -113,7 +132,7 @@ def session_graph_construction(self):
# To be compatible with existing datasets
item_behavior_seq = torch.tensor([0] * len(item_seq))
self.behavior_id_field = 'behavior_id'
self.field2id_token[self.behavior_id_field] = {0:'interaction'}
self.field2id_token[self.behavior_id_field] = {0: 'interaction'}
else:
item_behavior_seq = self.inter_feat[self.item_list_length_field]

@@ -152,6 +171,7 @@ def session_graph_construction(self):
'alias_inputs': alias_inputs
}


class LESSRDataset(SessionGraphDataset):
def session_graph_construction(self):
self.logger.info('Constructing LESSR session graphs.')
@@ -199,14 +219,14 @@ def reverse_session(self):
item_seq = self.inter_feat[self.item_id_list_field]
item_seq_len = self.inter_feat[self.item_list_length_field]
for i in tqdm(range(item_seq.shape[0])):
item_seq[i,:item_seq_len[i]] = item_seq[i,:item_seq_len[i]].flip(dims=[0])
item_seq[i, :item_seq_len[i]] = item_seq[i, :item_seq_len[i]].flip(dims=[0])

def bidirectional_edge(self, edge_index):
seq_len = edge_index.shape[1]
ed = edge_index.T
ed2 = edge_index.T.flip(dims=[1])
idc = ed.unsqueeze(1).expand(-1, seq_len, 2) == ed2.unsqueeze(0).expand(seq_len, -1, 2)
return torch.logical_and(idc[:,:,0], idc[:,:,1]).any(dim=-1)
return torch.logical_and(idc[:, :, 0], idc[:, :, 1]).any(dim=-1)

def session_graph_construction(self):
self.logger.info('Constructing session graphs.')
@@ -276,9 +296,10 @@ class SocialDataset(GeneralGraphDataset):
net_feat (pandas.DataFrame): Internal data structure stores the users' social network relations.
It's loaded from file ``.net``.
"""

def __init__(self, config):
super().__init__(config)

def _get_field_from_config(self):
super()._get_field_from_config()

@@ -410,4 +431,4 @@ def net_matrix(self, form='coo', value_field=None):
Returns:
scipy.sparse: Sparse matrix in form ``coo`` or ``csr``.
"""
return self._create_sparse_matrix(self.net_feat, self.net_src_field, self.net_tgt_field, form, value_field)
return self._create_sparse_matrix(self.net_feat, self.net_src_field, self.net_tgt_field, form, value_field)
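For context on the `save` override added above for RecBole 1.1.1: it writes the filtered dataset as a plain pickle under `checkpoint_dir`. A minimal sketch of reading such a checkpoint back, assuming the same filename pattern the override builds (the directory and dataset names below are placeholders):

```python
import os
import pickle

# Placeholder values; in practice these come from config["checkpoint_dir"]
# and config["dataset"].
checkpoint_dir = "saved"
dataset_name = "ml-100k"

# Mirrors the filename built by the overridden GeneralGraphDataset.save().
path = os.path.join(checkpoint_dir, f"{dataset_name}-GeneralGraphDataset.pth")
with open(path, "rb") as f:
    dataset = pickle.load(f)  # restores the filtered GeneralGraphDataset object
```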
1 change: 1 addition & 0 deletions recbole_gnn/model/general_recommender/__init__.py
Expand Up @@ -6,3 +6,4 @@
from recbole_gnn.model.general_recommender.lightgcl import LightGCL
from recbole_gnn.model.general_recommender.simgcl import SimGCL
from recbole_gnn.model.general_recommender.xsimgcl import XSimGCL
from recbole_gnn.model.general_recommender.directau import DirectAU
120 changes: 120 additions & 0 deletions recbole_gnn/model/general_recommender/directau.py
@@ -0,0 +1,120 @@
# r"""
# DirectAU
# ################################################
# Reference:
# Chenyang Wang et al. "Towards Representation Alignment and Uniformity in Collaborative Filtering." in KDD 2022.

# Reference code:
# https://github.com/THUwangcy/DirectAU
# """

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from recbole.model.init import xavier_normal_initialization
from recbole.utils import InputType
from recbole.model.general_recommender import BPR
from recbole_gnn.model.general_recommender import LightGCN

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender


class DirectAU(GeneralGraphRecommender):
input_type = InputType.PAIRWISE

def __init__(self, config, dataset):
super(DirectAU, self).__init__(config, dataset)

# load parameters info
self.embedding_size = config['embedding_size']
self.gamma = config['gamma']
self.encoder_name = config['encoder']

# define encoder
if self.encoder_name == 'MF':
self.encoder = MFEncoder(config, dataset)
elif self.encoder_name == 'LightGCN':
self.encoder = LGCNEncoder(config, dataset)
else:
raise ValueError('Non-implemented Encoder.')

# storage variables for full sort evaluation acceleration
self.restore_user_e = None
self.restore_item_e = None

# parameters initialization
self.apply(xavier_normal_initialization)

def forward(self, user, item):
user_e, item_e = self.encoder(user, item)
return F.normalize(user_e, dim=-1), F.normalize(item_e, dim=-1)

@staticmethod
def alignment(x, y, alpha=2):
return (x - y).norm(p=2, dim=1).pow(alpha).mean()

@staticmethod
def uniformity(x, t=2):
return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()

def calculate_loss(self, interaction):
if self.restore_user_e is not None or self.restore_item_e is not None:
self.restore_user_e, self.restore_item_e = None, None

user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]

user_e, item_e = self.forward(user, item)
align = self.alignment(user_e, item_e)
uniform = self.gamma * (self.uniformity(user_e) + self.uniformity(item_e)) / 2

return align, uniform

def predict(self, interaction):
user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]
        user_e, item_e = self.forward(user, item)  # reuse the encoder so both MF and LightGCN work
return torch.mul(user_e, item_e).sum(dim=1)

def full_sort_predict(self, interaction):
user = interaction[self.USER_ID]
if self.encoder_name == 'LightGCN':
if self.restore_user_e is None or self.restore_item_e is None:
self.restore_user_e, self.restore_item_e = self.encoder.get_all_embeddings()
user_e = self.restore_user_e[user]
all_item_e = self.restore_item_e
else:
user_e = self.encoder.user_embedding(user)
all_item_e = self.encoder.item_embedding.weight
score = torch.matmul(user_e, all_item_e.transpose(0, 1))
return score.view(-1)


class MFEncoder(BPR):
def __init__(self, config, dataset):
super(MFEncoder, self).__init__(config, dataset)

def forward(self, user_id, item_id):
return super().forward(user_id, item_id)

def get_all_embeddings(self):
user_embeddings = self.user_embedding.weight
item_embeddings = self.item_embedding.weight
return user_embeddings, item_embeddings


class LGCNEncoder(LightGCN):
def __init__(self, config, dataset):
super(LGCNEncoder, self).__init__(config, dataset)

def forward(self, user_id, item_id):
user_all_embeddings, item_all_embeddings = self.get_all_embeddings()
u_embed = user_all_embeddings[user_id]
i_embed = item_all_embeddings[item_id]
return u_embed, i_embed

def get_all_embeddings(self):
return super().forward()
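To make the new loss concrete: DirectAU optimizes an alignment term between matched user/item embeddings plus a uniformity term on each embedding set, exactly as the two static methods above compute. A self-contained sketch on random, L2-normalized embeddings (batch size, dimensionality, and the gamma value are illustrative only):

```python
import torch
import torch.nn.functional as F


def alignment(x, y, alpha=2):
    # Mean of ||x_i - y_i||^alpha over matched user/item pairs.
    return (x - y).norm(p=2, dim=1).pow(alpha).mean()


def uniformity(x, t=2):
    # log E[exp(-t * ||x_i - x_j||^2)] over all pairs within the batch.
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()


torch.manual_seed(0)
user_e = F.normalize(torch.randn(256, 64), dim=-1)
item_e = F.normalize(torch.randn(256, 64), dim=-1)

gamma = 0.5  # matches the default in DirectAU.yaml below
align = alignment(user_e, item_e)
uniform = gamma * (uniformity(user_e) + uniformity(item_e)) / 2
loss = align + uniform  # calculate_loss() returns the two terms separately
print(align.item(), uniform.item(), loss.item())
```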
7 changes: 7 additions & 0 deletions recbole_gnn/properties/model/DirectAU.yaml
@@ -0,0 +1,7 @@
embedding_size: 64
encoder: "MF"  # "MF" or "LightGCN"
gamma: 0.5
weight_decay: 1e-6
train_batch_size: 256

# n_layers: 3 # needed for LightGCN
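A hedged sketch of how the new model can be launched once this config is in place. It assumes the package's usual `run_recbole_gnn` entry point in `recbole_gnn.quick_start`; the dataset name and overrides are placeholders:

```python
# Assumes run_recbole_gnn exists in recbole_gnn.quick_start (the package's
# standard entry point); dataset and overrides below are placeholders.
from recbole_gnn.quick_start import run_recbole_gnn

config_dict = {
    "encoder": "MF",  # switch to "LightGCN" and set n_layers for the GNN encoder
    "gamma": 0.5,
    "train_batch_size": 256,
}
run_recbole_gnn(model="DirectAU", dataset="ml-100k", config_dict=config_dict)
```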
1 change: 1 addition & 0 deletions recbole_gnn/quick_start.py
@@ -80,6 +80,7 @@ def objective_function(config_dict=None, config_file_list=None, saved=True):
test_result = trainer.evaluate(test_data, load_best_model=saved)

return {
'model': config['model'],
'best_valid_score': best_valid_score,
'valid_score_bigger': config['valid_metric_bigger'],
'best_valid_result': best_valid_result,
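The one-line addition above puts the model name into the dict that `objective_function` returns, so hyperparameter-search results are labelled with the model they belong to. A hedged sketch of the consumer side, assuming RecBole's `HyperTuning` helper (the params file and config file names are placeholders):

```python
# Assumes recbole.trainer.HyperTuning and its usual keyword arguments;
# "hyper.test" and "directau.yaml" are placeholder file names.
from recbole.trainer import HyperTuning
from recbole_gnn.quick_start import objective_function

hp = HyperTuning(
    objective_function,
    algo="exhaustive",
    params_file="hyper.test",
    fixed_config_file_list=["directau.yaml"],
)
hp.run()
# Each trial's result dict now carries the 'model' key added in this diff.
print(hp.params2result)
```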
8 changes: 7 additions & 1 deletion tests/test_model.py
@@ -60,7 +60,7 @@ def test_simgcl(self):
'model': 'SimGCL'
}
quick_test(config_dict)

def test_xsimgcl(self):
config_dict = {
'model': 'XSimGCL'
@@ -73,6 +73,12 @@ def test_lightgcl(self):
}
quick_test(config_dict)

def test_directau(self):
config_dict = {
'model': 'DirectAU'
}
quick_test(config_dict)


class TestSequentialRecommender(unittest.TestCase):
def test_gru4rec(self):