-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from PaddlePaddle/develop
merge from remote origin/develop
- Loading branch information
Showing
9 changed files
with
568 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
repos: | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v1.2.3 | ||
hooks: | ||
- id: trailing-whitespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# 个性化推荐中的多视角Simnet模型 | ||
|
||
## 介绍 | ||
在个性化推荐场景中,推荐系统给用户提供的项目(Item)列表通常是通过个性化的匹配模型计算出来的。在现实世界中,一个用户可能有很多个视角的特征,比如用户Id,年龄,项目的点击历史等。一个项目,举例来说,新闻资讯,也会有多种视角的特征比如新闻标题,新闻类别等。多视角Simnet模型是可以融合用户以及推荐项目的多个视角的特征并进行个性化匹配学习的一体化模型。这类模型在很多工业化的场景中都会被使用到,比如百度的Feed产品中。 | ||
|
||
## 数据集 | ||
目前,本项目实用机器生成的数据集来介绍多视角Simnet模型的概念,未来我们会逐渐加入真是世界中的数据集并在这个模型上进行效果验证。 | ||
|
||
## 模型 | ||
本项目的目标是提供一个在个性化匹配场景下利用Paddle搭建的模型。多视角Simnet模型包括多个编码器模块,每个编码器被用在不同的特征视角上。当前,项目中提供Bag-of-Embedding编码器,Temporal-Convolutional编码器,和Gated-Recurrent-Unit编码器。我们会逐渐加入稀疏特征场景下比较实用的编码器到这个项目中。模型的训练方法,当前采用的是Pairwise ranking模式进行训练,即针对一对具有关联的User-Item组合,随机实用一个Item作为负例进行排序学习。 | ||
|
||
## 训练 | ||
如下 | ||
如下命令行可以获得训练工具的具体选项,`python train.py -h`内容可以参考说明 | ||
```bash | ||
python train.py | ||
``` | ||
## 未来的工作 | ||
- 多种pairwise的损失函数会被加入到这个项目中。对于不同视角的特征,用户-项目之间的匹配关系可以使用不同的损失函数进行联合优化。整个模型会在真实数据中进行验证。 | ||
- 推理工具会被加入 | ||
- Parallel Executor选项会被加入 | ||
- 分布式训练能力会被加入 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Multi-view Simnet for Personalized recommendation | ||
|
||
## Introduction | ||
In personalized recommendation scenario, a user often is provided with several items from personalized interest matching model. In real world application, a user may have multiple views of features, say user-id, age, click-history of items, search queries. A item, e.g. news, may also have multiple views of features like news title, news category, images in news and so on. Multi-view Simnet is matching a model that combine users' and items' multiple views of features into one unified model. The model can be used in many industrial product like Baidu's feed news. The model is adapted from the paper A Multi-View Deep Learning(MV-DNN) Approach for Cross Domain User Modeling in Recommendation Systems, WWW 2015. The difference between our model and the MV-DNN is that we also consider multiple feature views of users. | ||
|
||
## Dataset | ||
Currently, synthetic dataset is provided for proof of concept and we aim to add more real world dataset in this project in the future. | ||
|
||
## Model | ||
This project aims to provide practical usage of Paddle in personalized matching scenario. The model provides several encoder modules for different views of features. Currently, Bag-of-Embedding encoder, Temporal-Convolutional encoder, Gated-Recurrent-Unit encoder are provided. We will add more practical encoder for sparse features commonly used in recommender systems. Training algorithms used in this model is pairwise ranking in that a negative item with multiple views will be sampled given a pair of positive user-item pair. | ||
|
||
## Train | ||
The command line options for training can be listed by `python train.py -h` | ||
```bash | ||
python train.py | ||
``` | ||
|
||
## Future work | ||
- Multiple types of pairwise loss will be added in this project. For different views of features between a user and an item, multiple losses will be supported. The model will be verified in real world dataset. | ||
- infer will be added | ||
- Parallel Executor will be added in this project | ||
- Distributed Training will be added |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle.fluid as fluid | ||
import paddle.fluid.layers.nn as nn | ||
import paddle.fluid.layers.tensor as tensor | ||
import paddle.fluid.layers.control_flow as cf | ||
import paddle.fluid.layers.io as io | ||
|
||
|
||
class BowEncoder(object): | ||
""" bow-encoder """ | ||
|
||
def __init__(self): | ||
self.param_name = "" | ||
|
||
def forward(self, emb): | ||
return nn.sequence_pool(input=emb, pool_type='sum') | ||
|
||
|
||
class CNNEncoder(object): | ||
""" cnn-encoder""" | ||
|
||
def __init__(self, | ||
param_name="cnn.w", | ||
win_size=3, | ||
ksize=128, | ||
act='tanh', | ||
pool_type='max'): | ||
self.param_name = param_name | ||
self.win_size = win_size | ||
self.ksize = ksize | ||
self.act = act | ||
self.pool_type = pool_type | ||
|
||
def forward(self, emb): | ||
return fluid.nets.sequence_conv_pool( | ||
input=emb, | ||
num_filters=self.ksize, | ||
filter_size=self.win_size, | ||
act=self.act, | ||
pool_type=self.pool_type, | ||
attr=self.param_name) | ||
|
||
|
||
class GrnnEncoder(object): | ||
""" grnn-encoder """ | ||
|
||
def __init__(self, param_name="grnn.w", hidden_size=128): | ||
self.param_name = args | ||
self.hidden_size = hidden_size | ||
|
||
def forward(self, emb): | ||
fc0 = nn.fc(input=emb, size=self.hidden_size * 3) | ||
gru_h = nn.dynamic_gru( | ||
input=emb, | ||
size=self.hidden_size, | ||
is_reverse=False, | ||
attr=self.param_name) | ||
return nn.sequence_pool(input=gru_h, pool_type='max') | ||
|
||
|
||
'''this is a very simple Encoder factory | ||
most default argument values are used''' | ||
|
||
|
||
class SimpleEncoderFactory(object): | ||
def __init__(self): | ||
pass | ||
|
||
''' create an encoder through create function ''' | ||
|
||
def create(self, enc_type, enc_hid_size): | ||
if enc_type == "bow": | ||
bow_encode = BowEncoder() | ||
return bow_encode | ||
elif enc_type == "cnn": | ||
cnn_encode = CNNEncoder(ksize=enc_hid_size) | ||
return cnn_encode | ||
elif enc_type == "gru": | ||
rnn_encode = GrnnEncoder(hidden_size=enc_hid_size) | ||
return rnn_encode | ||
|
||
|
||
class MultiviewSimnet(object): | ||
""" multi-view simnet """ | ||
|
||
def __init__(self, embedding_size, embedding_dim, hidden_size): | ||
self.embedding_size = embedding_size | ||
self.embedding_dim = embedding_dim | ||
self.emb_shape = [self.embedding_size, self.embedding_dim] | ||
self.hidden_size = hidden_size | ||
self.margin = 0.1 | ||
|
||
def set_query_encoder(self, encoders): | ||
self.query_encoders = encoders | ||
|
||
def set_title_encoder(self, encoders): | ||
self.title_encoders = encoders | ||
|
||
def get_correct(self, x, y): | ||
less = tensor.cast(cf.less_than(x, y), dtype='float32') | ||
correct = nn.reduce_sum(less) | ||
return correct | ||
|
||
def train_net(self): | ||
# input fields for query, pos_title, neg_title | ||
q_slots = [ | ||
io.data( | ||
name="q%d" % i, shape=[1], lod_level=1, dtype='int64') | ||
for i in range(len(self.query_encoders)) | ||
] | ||
pt_slots = [ | ||
io.data( | ||
name="pt%d" % i, shape=[1], lod_level=1, dtype='int64') | ||
for i in range(len(self.title_encoders)) | ||
] | ||
nt_slots = [ | ||
io.data( | ||
name="nt%d" % i, shape=[1], lod_level=1, dtype='int64') | ||
for i in range(len(self.title_encoders)) | ||
] | ||
|
||
# lookup embedding for each slot | ||
q_embs = [ | ||
nn.embedding( | ||
input=query, size=self.emb_shape, param_attr="emb.w") | ||
for query in q_slots | ||
] | ||
pt_embs = [ | ||
nn.embedding( | ||
input=title, size=self.emb_shape, param_attr="emb.w") | ||
for title in pt_slots | ||
] | ||
nt_embs = [ | ||
nn.embedding( | ||
input=title, size=self.emb_shape, param_attr="emb.w") | ||
for title in nt_slots | ||
] | ||
|
||
# encode each embedding field with encoder | ||
q_encodes = [ | ||
self.query_encoders[i].forward(emb) for i, emb in enumerate(q_embs) | ||
] | ||
pt_encodes = [ | ||
self.title_encoders[i].forward(emb) for i, emb in enumerate(pt_embs) | ||
] | ||
nt_encodes = [ | ||
self.title_encoders[i].forward(emb) for i, emb in enumerate(nt_embs) | ||
] | ||
|
||
# concat multi view for query, pos_title, neg_title | ||
q_concat = nn.concat(q_encodes) | ||
pt_concat = nn.concat(pt_encodes) | ||
nt_concat = nn.concat(nt_encodes) | ||
|
||
# projection of hidden layer | ||
q_hid = nn.fc(q_concat, size=self.hidden_size, param_attr='q_fc.w') | ||
pt_hid = nn.fc(pt_concat, size=self.hidden_size, param_attr='t_fc.w') | ||
nt_hid = nn.fc(nt_concat, size=self.hidden_size, param_attr='t_fc.w') | ||
|
||
# cosine of hidden layers | ||
cos_pos = nn.cos_sim(q_hid, pt_hid) | ||
cos_neg = nn.cos_sim(q_hid, nt_hid) | ||
|
||
# pairwise hinge_loss | ||
loss_part1 = nn.elementwise_sub( | ||
tensor.fill_constant_batch_size_like( | ||
input=cos_pos, | ||
shape=[-1, 1], | ||
value=self.margin, | ||
dtype='float32'), | ||
cos_pos) | ||
|
||
loss_part2 = nn.elementwise_add(loss_part1, cos_neg) | ||
|
||
loss_part3 = nn.elementwise_max( | ||
tensor.fill_constant_batch_size_like( | ||
input=loss_part2, shape=[-1, 1], value=0.0, dtype='float32'), | ||
loss_part2) | ||
|
||
avg_cost = nn.mean(loss_part3) | ||
correct = self.get_correct(cos_pos, cos_neg) | ||
|
||
return q_slots + pt_slots + nt_slots, avg_cost, correct | ||
|
||
def pred_net(self, query_fields, pos_title_fields, neg_title_fields): | ||
q_slots = [ | ||
io.data( | ||
name="q%d" % i, shape=[1], lod_level=1, dtype='int64') | ||
for i in range(len(self.query_encoders)) | ||
] | ||
pt_slots = [ | ||
io.data( | ||
name="pt%d" % i, shape=[1], lod_level=1, dtype='int64') | ||
for i in range(len(self.title_encoders)) | ||
] | ||
# lookup embedding for each slot | ||
q_embs = [ | ||
nn.embedding( | ||
input=query, size=self.emb_shape, param_attr="emb.w") | ||
for query in q_slots | ||
] | ||
pt_embs = [ | ||
nn.embedding( | ||
input=title, size=self.emb_shape, param_attr="emb.w") | ||
for title in pt_slots | ||
] | ||
# encode each embedding field with encoder | ||
q_encodes = [ | ||
self.query_encoder[i].forward(emb) for i, emb in enumerate(q_embs) | ||
] | ||
pt_encodes = [ | ||
self.title_encoders[i].forward(emb) for i, emb in enumerate(pt_embs) | ||
] | ||
# concat multi view for query, pos_title, neg_title | ||
q_concat = nn.concat(q_encodes) | ||
pt_concat = nn.concat(pt_encodes) | ||
# projection of hidden layer | ||
q_hid = nn.fc(q_concat, size=self.hidden_size, param_attr='q_fc.w') | ||
pt_hid = nn.fc(pt_concat, size=self.hidden_size, param_attr='t_fc.w') | ||
# cosine of hidden layers | ||
cos = nn.cos_sim(q_hid, pt_hid) | ||
return cos |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import random | ||
|
||
|
||
class Dataset: | ||
def __init__(self): | ||
pass | ||
|
||
|
||
class SyntheticDataset(Dataset): | ||
def __init__(self, sparse_feature_dim, query_slot_num, title_slot_num): | ||
# ids are randomly generated | ||
self.ids_per_slot = 10 | ||
self.sparse_feature_dim = sparse_feature_dim | ||
self.query_slot_num = query_slot_num | ||
self.title_slot_num = title_slot_num | ||
self.dataset_size = 10000 | ||
|
||
def _reader_creator(self, is_train): | ||
def generate_ids(num, space): | ||
return [random.randint(0, space - 1) for i in range(num)] | ||
|
||
def reader(): | ||
for i in range(self.dataset_size): | ||
query_slots = [] | ||
pos_title_slots = [] | ||
neg_title_slots = [] | ||
for i in range(self.query_slot_num): | ||
qslot = generate_ids(self.ids_per_slot, | ||
self.sparse_feature_dim) | ||
query_slots.append(qslot) | ||
for i in range(self.title_slot_num): | ||
pt_slot = generate_ids(self.ids_per_slot, | ||
self.sparse_feature_dim) | ||
pos_title_slots.append(pt_slot) | ||
if is_train: | ||
for i in range(self.title_slot_num): | ||
nt_slot = generate_ids(self.ids_per_slot, | ||
self.sparse_feature_dim) | ||
neg_title_slots.append(nt_slot) | ||
yield query_slots + pos_title_slots + neg_title_slots | ||
else: | ||
yield query_slots + pos_title_slots | ||
|
||
return reader | ||
|
||
def train(self): | ||
return self._reader_creator(True) | ||
|
||
def valid(self): | ||
return self._reader_creator(True) | ||
|
||
def test(self): | ||
return self._reader_creator(False) |
Oops, something went wrong.