From a2614b1c826b945a9420733f6c68e62bb2c37165 Mon Sep 17 00:00:00 2001 From: Wicknight <1462034631@qq.com> Date: Mon, 4 Apr 2022 14:02:40 +0800 Subject: [PATCH 1/5] FEA: Add CoNet model --- .../model/cross_domain_recommender/conet.py | 191 ++++++++++++++++++ recbole_cdr/properties/model/CoNet.yaml | 3 + 2 files changed, 194 insertions(+) create mode 100644 recbole_cdr/model/cross_domain_recommender/conet.py create mode 100644 recbole_cdr/properties/model/CoNet.yaml diff --git a/recbole_cdr/model/cross_domain_recommender/conet.py b/recbole_cdr/model/cross_domain_recommender/conet.py new file mode 100644 index 0000000..ddcb6ba --- /dev/null +++ b/recbole_cdr/model/cross_domain_recommender/conet.py @@ -0,0 +1,191 @@ +# -*- coding: utf-8 -*- +# @Time : 2022/3/30 +# @Author : Gaowei Zhang +# @Email : 1462034631@qq.com + +r""" +BiTGCF +################################################ +Reference: + Guangneng Hu et al. "CoNet: Collaborative Cross Networks for Cross-Domain Recommendation." in CIKM 2018. +""" + +import numpy as np + +import torch +import torch.nn as nn + +from recbole_cdr.model.crossdomain_recommender import CrossDomainRecommender +from recbole.model.init import xavier_normal_initialization, xavier_normal_ +from recbole.utils import InputType + + +class CoNet(CrossDomainRecommender): + r""" + + """ + input_type = InputType.POINTWISE + + def __init__(self, config, dataset): + super(CoNet, self).__init__(config, dataset) + + # load dataset info + self.SOURCE_LABEL = dataset.source_domain_dataset.label_field + self.TARGET_LABEL = dataset.target_domain_dataset.label_field + + # load parameters info + self.device = config['device'] + + # load parameters info + self.latent_dim = config['embedding_size'] # int type:the embedding size of lightGCN + self.reg_weight = config['reg_weight'] # float32 type: the weight decay for l2 normalization + self.cross_layers = config["mlp_hidden_size"] #list type: the list of hidden lyaers size + + # define layers and loss + self.source_user_embedding = torch.nn.Embedding(num_embeddings=self.total_num_users, embedding_dim=self.latent_dim) + self.target_user_embedding = torch.nn.Embedding(num_embeddings=self.total_num_users, embedding_dim=self.latent_dim) + + self.source_item_embedding = torch.nn.Embedding(num_embeddings=self.total_num_items, embedding_dim=self.latent_dim) + self.target_item_embedding = torch.nn.Embedding(num_embeddings=self.total_num_items, embedding_dim=self.latent_dim) + + self.loss = nn.BCELoss() + + with torch.no_grad(): + self.source_user_embedding.weight[self.overlapped_num_users: self.target_num_users].fill_(0) + self.source_item_embedding.weight[self.overlapped_num_items: self.target_num_items].fill_(0) + + self.target_user_embedding.weight[self.target_num_users:].fill_(0) + self.target_item_embedding.weight[self.target_num_items:].fill_(0) + + self.source_crossunit = self.cross_units([2 * self.latent_dim] + self.cross_layers) + self.source_outputunit = nn.Sequential( + nn.Linear(self.cross_layers[-1], 1), + nn.Sigmoid() + ) + + self.target_crossunit = self.cross_units([2 * self.latent_dim] + self.cross_layers) + self.target_outputunit = nn.Sequential( + nn.Linear(self.cross_layers[-1], 1), + nn.Sigmoid() + ) + + self.crossparas = self.cross_parameters([2 * self.latent_dim] + self.cross_layers) + + # parameters initialization + self.apply(xavier_normal_initialization) + + def cross_units(self, cross_layers): + cross_modules = [] + for i, (d_in, d_out) in enumerate(zip(cross_layers[:-1], cross_layers[1:])): + module = (nn.Linear(d_in, d_out), nn.ReLU()) + cross_modules.append(module) + return cross_modules + + def cross_parameters(self, cross_layers): + cross_paras = [] + for i, (d_in, d_out) in enumerate(zip(cross_layers[:-1], cross_layers[1:])): + para = nn.Parameter(torch.empty(d_in, d_out)) + xavier_normal_(para) + cross_paras.append(para) + return cross_paras + + def forward(self, source_user, source_item, target_user, target_item): + source_user_embedding = self.source_user_embedding(source_user) + source_item_embedding = self.source_item_embedding(source_item) + target_user_embedding = self.target_user_embedding(target_user) + target_item_embedding = self.target_item_embedding(target_item) + + source_overlapuser_idx = set([idx for idx in range(source_user.shape[0]) if source_user[idx] < self.overlapped_num_users]) + source_overlapitem_idx = set([idx for idx in range(source_item.shape[0]) if source_item[idx] < self.overlapped_num_items]) + source_overlap_idx = source_overlapuser_idx.union(source_overlapitem_idx) + target_overlapuser_idx = set([idx for idx in range(target_user.shape[0]) if target_user[idx] < self.overlapped_num_users]) + target_overlapitem_idx = set([idx for idx in range(target_item.shape[0]) if target_item[idx] < self.overlapped_num_items]) + target_overlap_idx = target_overlapuser_idx.union(target_overlapitem_idx) + overlap_idx = torch.from_numpy(np.array(list(source_overlap_idx.intersection(target_overlap_idx)))).to(self.device) + + source_crossinput = torch.cat([source_user_embedding, source_item_embedding], dim=1).to(self.device) + target_crossinput = torch.cat([target_user_embedding, target_item_embedding], dim=1).to(self.device) + + for i, (source_cross_module, target_cross_module) in enumerate(zip(self.source_crossunit, self.target_crossunit)): + source_fc_module, source_act_module = source_cross_module + source_fc_module = source_fc_module.to(self.device) + source_act_module = source_act_module.to(self.device) + cross_para = self.crossparas[i].to(self.device) + target_fc_module, target_act_module = target_cross_module + target_fc_module = target_fc_module.to(self.device) + target_act_module = target_act_module.to(self.device) + source_crossoutput = source_fc_module(source_crossinput) + source_crossoutput[overlap_idx] = source_crossoutput[overlap_idx] + torch.mm(target_crossinput, cross_para)[overlap_idx] + source_crossoutput = source_act_module(source_crossoutput) + + target_crossoutput = target_fc_module(target_crossinput) + target_crossoutput[overlap_idx] = target_crossoutput[overlap_idx] + torch.mm(source_crossinput, cross_para)[overlap_idx] + target_crossoutput = target_act_module(target_crossoutput) + + source_crossinput = source_crossoutput + target_crossinput = target_crossoutput + + source_out = self.source_outputunit(source_crossoutput) + target_out = self.target_outputunit(target_crossoutput) + + return source_out, target_out + + def calculate_loss(self, interaction): + source_user = interaction[self.SOURCE_USER_ID] + source_item = interaction[self.SOURCE_ITEM_ID] + source_label = interaction[self.SOURCE_LABEL] + + target_user = interaction[self.TARGET_USER_ID] + target_item = interaction[self.TARGET_ITEM_ID] + target_label = interaction[self.TARGET_LABEL] + + p_source, p_target = self.forward(source_user, source_item, target_user, target_item) + + loss_s = self.loss(p_source, source_label) + loss_t = self.loss(p_target, target_label) + + reg_loss = 0 + for para in self.crossparas: + reg_loss += torch.norm(para) + loss = loss_s + loss_t + reg_loss + + return loss + + def predict(self, interaction): + user = interaction[self.TARGET_USER_ID] + item = interaction[self.TARGET_ITEM_ID] + user_e = self.target_user_embedding(user) + item_e = self.target_item_embedding(item) + input = torch.cat([user_e, item_e], dim=1) + + for i, target_cross_module in enumerate(self.target_crossunit): + target_fc_module, target_act_module = target_cross_module + output = target_act_module(target_fc_module(input)) + + input = output + + p = self.target_outputunit(output) + return p + + def full_sort_predict(self, interaction): + user = interaction[self.TARGET_USER_ID] + user_e = self.target_user_embedding(user) + user_num = user_e.shape[0] + all_item_e = self.target_item_embedding.weight[:self.target_num_items] + item_num = all_item_e.shape[0] + all_user_e = user_e.repeat(1, item_num).view(-1, self.latent_dim) + user_e_list = torch.split(all_user_e, [item_num]*user_num) + score_list = [] + for u_embed in user_e_list: + input = torch.cat([u_embed, all_item_e], dim=1) + for i, target_cross_module in enumerate(self.target_crossunit): + target_fc_module, target_act_module = target_cross_module # [d_in, d_out] + output = target_act_module(target_fc_module(input)) + + input = output + + p = self.target_outputunit(output) + print(p.shape) + score_list.append(p) + score = torch.cat(score_list, dim=1).transpose(0, 1) + return score \ No newline at end of file diff --git a/recbole_cdr/properties/model/CoNet.yaml b/recbole_cdr/properties/model/CoNet.yaml new file mode 100644 index 0000000..c09a516 --- /dev/null +++ b/recbole_cdr/properties/model/CoNet.yaml @@ -0,0 +1,3 @@ +embedding_size: 64 +reg_weight: 1e-2 +mlp_hidden_size: [128,64] \ No newline at end of file From 0ac32e14077cadb13539c5a712c9a9f656e26c1f Mon Sep 17 00:00:00 2001 From: Wicknight <1462034631@qq.com> Date: Mon, 4 Apr 2022 14:06:16 +0800 Subject: [PATCH 2/5] Format: Change model config --- recbole_cdr/model/cross_domain_recommender/conet.py | 2 +- recbole_cdr/properties/model/CoNet.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/recbole_cdr/model/cross_domain_recommender/conet.py b/recbole_cdr/model/cross_domain_recommender/conet.py index ddcb6ba..be99560 100644 --- a/recbole_cdr/model/cross_domain_recommender/conet.py +++ b/recbole_cdr/model/cross_domain_recommender/conet.py @@ -4,7 +4,7 @@ # @Email : 1462034631@qq.com r""" -BiTGCF +CoNet ################################################ Reference: Guangneng Hu et al. "CoNet: Collaborative Cross Networks for Cross-Domain Recommendation." in CIKM 2018. diff --git a/recbole_cdr/properties/model/CoNet.yaml b/recbole_cdr/properties/model/CoNet.yaml index c09a516..a5d8d32 100644 --- a/recbole_cdr/properties/model/CoNet.yaml +++ b/recbole_cdr/properties/model/CoNet.yaml @@ -1,3 +1,3 @@ embedding_size: 64 reg_weight: 1e-2 -mlp_hidden_size: [128,64] \ No newline at end of file +mlp_hidden_size: [256,128,64] \ No newline at end of file From 5f650aa725ad36afef84e7d86e414bc27a4841f6 Mon Sep 17 00:00:00 2001 From: Wicknight <1462034631@qq.com> Date: Mon, 4 Apr 2022 14:07:35 +0800 Subject: [PATCH 3/5] Format: Remove debug information --- recbole_cdr/model/cross_domain_recommender/conet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/recbole_cdr/model/cross_domain_recommender/conet.py b/recbole_cdr/model/cross_domain_recommender/conet.py index be99560..a78fe0f 100644 --- a/recbole_cdr/model/cross_domain_recommender/conet.py +++ b/recbole_cdr/model/cross_domain_recommender/conet.py @@ -185,7 +185,6 @@ def full_sort_predict(self, interaction): input = output p = self.target_outputunit(output) - print(p.shape) score_list.append(p) score = torch.cat(score_list, dim=1).transpose(0, 1) return score \ No newline at end of file From 70e3a3b50d7168ea73af6d40d7327a295c082514 Mon Sep 17 00:00:00 2001 From: Wicknight <1462034631@qq.com> Date: Wed, 6 Apr 2022 15:27:58 +0800 Subject: [PATCH 4/5] FIX: Bug in forward --- .../model/cross_domain_recommender/conet.py | 90 ++++++++++++++----- 1 file changed, 70 insertions(+), 20 deletions(-) diff --git a/recbole_cdr/model/cross_domain_recommender/conet.py b/recbole_cdr/model/cross_domain_recommender/conet.py index a78fe0f..68180bc 100644 --- a/recbole_cdr/model/cross_domain_recommender/conet.py +++ b/recbole_cdr/model/cross_domain_recommender/conet.py @@ -33,6 +33,15 @@ def __init__(self, config, dataset): self.SOURCE_LABEL = dataset.source_domain_dataset.label_field self.TARGET_LABEL = dataset.target_domain_dataset.label_field + assert self.overlapped_num_items == 1 or self.overlapped_num_users == 1, \ + "CoNet model only support user overlapped or item overlapped dataset! " + if self.overlapped_num_users > 1: + self.mode = 'overlap_users' + elif self.overlapped_num_items > 1: + self.mode = 'overlap_items' + else: + self.mode = 'non_overlap' + # load parameters info self.device = config['device'] @@ -89,24 +98,61 @@ def cross_parameters(self, cross_layers): cross_paras.append(para) return cross_paras - def forward(self, source_user, source_item, target_user, target_item): - source_user_embedding = self.source_user_embedding(source_user) - source_item_embedding = self.source_item_embedding(source_item) - target_user_embedding = self.target_user_embedding(target_user) - target_item_embedding = self.target_item_embedding(target_item) + def source_forward(self, user, item): + source_user_embedding = self.source_user_embedding(user) + source_item_embedding = self.source_item_embedding(item) + target_user_embedding = self.target_user_embedding(user) + target_item_embedding = self.target_item_embedding(item) + source_crossinput = torch.cat([source_user_embedding, source_item_embedding], dim=1).to(self.device) + target_crossinput = torch.cat([target_user_embedding, target_item_embedding], dim=1).to(self.device) - source_overlapuser_idx = set([idx for idx in range(source_user.shape[0]) if source_user[idx] < self.overlapped_num_users]) - source_overlapitem_idx = set([idx for idx in range(source_item.shape[0]) if source_item[idx] < self.overlapped_num_items]) - source_overlap_idx = source_overlapuser_idx.union(source_overlapitem_idx) - target_overlapuser_idx = set([idx for idx in range(target_user.shape[0]) if target_user[idx] < self.overlapped_num_users]) - target_overlapitem_idx = set([idx for idx in range(target_item.shape[0]) if target_item[idx] < self.overlapped_num_items]) - target_overlap_idx = target_overlapuser_idx.union(target_overlapitem_idx) - overlap_idx = torch.from_numpy(np.array(list(source_overlap_idx.intersection(target_overlap_idx)))).to(self.device) + if self.mode == 'overlap_users': + overlap_idx = user > self.overlapped_num_users + else: + overlap_idx = item > self.overlapped_num_items + for i, (source_cross_module, target_cross_module) in enumerate( + zip(self.source_crossunit, self.target_crossunit)): + source_fc_module, source_act_module = source_cross_module + source_fc_module = source_fc_module.to(self.device) + source_act_module = source_act_module.to(self.device) + cross_para = self.crossparas[i].to(self.device) + target_fc_module, target_act_module = target_cross_module + target_fc_module = target_fc_module.to(self.device) + target_act_module = target_act_module.to(self.device) + + source_crossoutput = source_fc_module(source_crossinput) + source_crossoutput[overlap_idx] = source_crossoutput[overlap_idx] + torch.mm(target_crossinput, cross_para)[ + overlap_idx] + source_crossoutput = source_act_module(source_crossoutput) + + target_crossoutput = target_fc_module(target_crossinput) + target_crossoutput[overlap_idx] = target_crossoutput[overlap_idx] + torch.mm(source_crossinput, cross_para)[ + overlap_idx] + target_crossoutput = target_act_module(target_crossoutput) + + source_crossinput = source_crossoutput + target_crossinput = target_crossoutput + + source_out = self.source_outputunit(source_crossoutput).squeeze() + + return source_out + + def target_forward(self, user, item): + source_user_embedding = self.source_user_embedding(user) + source_item_embedding = self.source_item_embedding(item) + target_user_embedding = self.target_user_embedding(user) + target_item_embedding = self.target_item_embedding(item) source_crossinput = torch.cat([source_user_embedding, source_item_embedding], dim=1).to(self.device) target_crossinput = torch.cat([target_user_embedding, target_item_embedding], dim=1).to(self.device) - for i, (source_cross_module, target_cross_module) in enumerate(zip(self.source_crossunit, self.target_crossunit)): + if self.mode == 'overlap_users': + overlap_idx = user > self.overlapped_num_users + else: + overlap_idx = item > self.overlapped_num_items + + for i, (source_cross_module, target_cross_module) in enumerate( + zip(self.source_crossunit, self.target_crossunit)): source_fc_module, source_act_module = source_cross_module source_fc_module = source_fc_module.to(self.device) source_act_module = source_act_module.to(self.device) @@ -114,21 +160,23 @@ def forward(self, source_user, source_item, target_user, target_item): target_fc_module, target_act_module = target_cross_module target_fc_module = target_fc_module.to(self.device) target_act_module = target_act_module.to(self.device) + source_crossoutput = source_fc_module(source_crossinput) - source_crossoutput[overlap_idx] = source_crossoutput[overlap_idx] + torch.mm(target_crossinput, cross_para)[overlap_idx] + source_crossoutput[overlap_idx] = source_crossoutput[overlap_idx] + torch.mm(target_crossinput, cross_para)[ + overlap_idx] source_crossoutput = source_act_module(source_crossoutput) target_crossoutput = target_fc_module(target_crossinput) - target_crossoutput[overlap_idx] = target_crossoutput[overlap_idx] + torch.mm(source_crossinput, cross_para)[overlap_idx] + target_crossoutput[overlap_idx] = target_crossoutput[overlap_idx] + torch.mm(source_crossinput, cross_para)[ + overlap_idx] target_crossoutput = target_act_module(target_crossoutput) source_crossinput = source_crossoutput target_crossinput = target_crossoutput - source_out = self.source_outputunit(source_crossoutput) - target_out = self.target_outputunit(target_crossoutput) + target_out = self.target_outputunit(target_crossoutput).squeeze() - return source_out, target_out + return target_out def calculate_loss(self, interaction): source_user = interaction[self.SOURCE_USER_ID] @@ -139,7 +187,8 @@ def calculate_loss(self, interaction): target_item = interaction[self.TARGET_ITEM_ID] target_label = interaction[self.TARGET_LABEL] - p_source, p_target = self.forward(source_user, source_item, target_user, target_item) + p_source = self.source_forward(source_user, source_item) + p_target = self.target_forward(target_user, target_item) loss_s = self.loss(p_source, source_label) loss_t = self.loss(p_target, target_label) @@ -165,6 +214,7 @@ def predict(self, interaction): input = output p = self.target_outputunit(output) + return p def full_sort_predict(self, interaction): @@ -179,7 +229,7 @@ def full_sort_predict(self, interaction): for u_embed in user_e_list: input = torch.cat([u_embed, all_item_e], dim=1) for i, target_cross_module in enumerate(self.target_crossunit): - target_fc_module, target_act_module = target_cross_module # [d_in, d_out] + target_fc_module, target_act_module = target_cross_module output = target_act_module(target_fc_module(input)) input = output From 636d730a4a205c06347a6c19a48d4dea0e90735f Mon Sep 17 00:00:00 2001 From: Wicknight <1462034631@qq.com> Date: Thu, 7 Apr 2022 19:24:36 +0800 Subject: [PATCH 5/5] Format: cross-units --- .../model/cross_domain_recommender/conet.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/recbole_cdr/model/cross_domain_recommender/conet.py b/recbole_cdr/model/cross_domain_recommender/conet.py index 68180bc..232c76b 100644 --- a/recbole_cdr/model/cross_domain_recommender/conet.py +++ b/recbole_cdr/model/cross_domain_recommender/conet.py @@ -88,7 +88,7 @@ def cross_units(self, cross_layers): for i, (d_in, d_out) in enumerate(zip(cross_layers[:-1], cross_layers[1:])): module = (nn.Linear(d_in, d_out), nn.ReLU()) cross_modules.append(module) - return cross_modules + return nn.ModuleList(cross_modules) def cross_parameters(self, cross_layers): cross_paras = [] @@ -96,7 +96,7 @@ def cross_parameters(self, cross_layers): para = nn.Parameter(torch.empty(d_in, d_out)) xavier_normal_(para) cross_paras.append(para) - return cross_paras + return nn.ModuleList(cross_paras) def source_forward(self, user, item): source_user_embedding = self.source_user_embedding(user) @@ -114,12 +114,12 @@ def source_forward(self, user, item): for i, (source_cross_module, target_cross_module) in enumerate( zip(self.source_crossunit, self.target_crossunit)): source_fc_module, source_act_module = source_cross_module - source_fc_module = source_fc_module.to(self.device) - source_act_module = source_act_module.to(self.device) - cross_para = self.crossparas[i].to(self.device) + source_fc_module = source_fc_module + source_act_module = source_act_module + cross_para = self.crossparas[i] target_fc_module, target_act_module = target_cross_module - target_fc_module = target_fc_module.to(self.device) - target_act_module = target_act_module.to(self.device) + target_fc_module = target_fc_module + target_act_module = target_act_module source_crossoutput = source_fc_module(source_crossinput) source_crossoutput[overlap_idx] = source_crossoutput[overlap_idx] + torch.mm(target_crossinput, cross_para)[ @@ -154,12 +154,12 @@ def target_forward(self, user, item): for i, (source_cross_module, target_cross_module) in enumerate( zip(self.source_crossunit, self.target_crossunit)): source_fc_module, source_act_module = source_cross_module - source_fc_module = source_fc_module.to(self.device) - source_act_module = source_act_module.to(self.device) - cross_para = self.crossparas[i].to(self.device) + source_fc_module = source_fc_module + source_act_module = source_act_module + cross_para = self.crossparas[i] target_fc_module, target_act_module = target_cross_module - target_fc_module = target_fc_module.to(self.device) - target_act_module = target_act_module.to(self.device) + target_fc_module = target_fc_module + target_act_module = target_act_module source_crossoutput = source_fc_module(source_crossinput) source_crossoutput[overlap_idx] = source_crossoutput[overlap_idx] + torch.mm(target_crossinput, cross_para)[