-
Notifications
You must be signed in to change notification settings - Fork 343
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat]: add sequence feature negative sample process (#267)
* add negative sample process * add multi_tower_recall docs
- Loading branch information
Showing
13 changed files
with
587 additions
and
16 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# MultiTowerRecall | ||
|
||
### 简介 | ||
|
||
专为负采样和序列特征训练准备的双塔召回模型,分为user塔和item塔。 | ||
注:使用时需指定user id和item id。 | ||
|
||
### 配置说明 | ||
|
||
```protobuf | ||
model_config:{ | ||
model_class: "MultiTowerRecall" | ||
feature_groups: { | ||
group_name: 'user' | ||
feature_names: 'user_id' | ||
feature_names: 'cms_segid' | ||
feature_names: 'cms_group_id' | ||
feature_names: 'age_level' | ||
feature_names: 'pvalue_level' | ||
feature_names: 'shopping_level' | ||
feature_names: 'occupation' | ||
feature_names: 'new_user_class_level' | ||
wide_deep:DEEP | ||
negative_sampler:true | ||
sequence_features: { | ||
group_name: "seq_fea" | ||
allow_key_search: true | ||
need_key_feature:true | ||
seq_att_map: { | ||
key: "brand" | ||
key: "cate_id" | ||
hist_seq: "tag_brand_list" | ||
hist_seq: "tag_category_list" | ||
} | ||
} | ||
} | ||
feature_groups: { | ||
group_name: "item" | ||
feature_names: 'adgroup_id' | ||
feature_names: 'cate_id' | ||
feature_names: 'campaign_id' | ||
feature_names: 'customer' | ||
feature_names: 'brand' | ||
wide_deep:DEEP | ||
} | ||
multi_tower_recall { | ||
user_tower { | ||
id: "user_id" | ||
dnn { | ||
hidden_units: [256, 128, 64, 32] | ||
# dropout_ratio : [0.1, 0.1, 0.1, 0.1] | ||
} | ||
} | ||
item_tower { | ||
id: "adgroup_id" | ||
dnn { | ||
hidden_units: [256, 128, 64, 32] | ||
} | ||
} | ||
final_dnn { | ||
hidden_units: [128, 96, 64, 32, 16] | ||
} | ||
l2_regularization: 1e-6 | ||
} | ||
loss_type: CLASSIFICATION | ||
embedding_regularization: 5e-6 | ||
} | ||
``` | ||
|
||
- model_class: 'MultiTowerRecall', 不需要修改 | ||
- feature_groups: 需要两个feature_group: user和item, **group name不能变** | ||
- multi_tower_recall: multi_tower_recall相关的参数,必须配置user_tower和item_tower | ||
- user_tower/item_tower: | ||
- dnn: deep part的参数配置 | ||
- hidden_units: dnn每一层的channel数目,即神经元的数目 | ||
- embedding_regularization: 对embedding部分加regularization,防止overfit | ||
|
||
支持的metric_set包括: | ||
|
||
- auc | ||
- mean_absolute_error | ||
- accuracy | ||
|
||
### 示例Config | ||
|
||
见路径:samples/model_config/multi_tower_recall_neg_sampler_sequence_feature.config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,3 +24,4 @@ | |
:maxdepth: 2 | ||
|
||
cmbf | ||
uniter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# -*- encoding:utf-8 -*- | ||
# Copyright (c) Alibaba, Inc. and its affiliates. | ||
|
||
import tensorflow as tf | ||
|
||
from easy_rec.python.layers import dnn | ||
from easy_rec.python.model.rank_model import RankModel | ||
|
||
from easy_rec.python.protos.multi_tower_recall_pb2 import MultiTowerRecall as MultiTowerRecallConfig # NOQA | ||
|
||
if tf.__version__ >= '2.0': | ||
tf = tf.compat.v1 | ||
|
||
|
||
class MultiTowerRecall(RankModel): | ||
|
||
def __init__(self, | ||
model_config, | ||
feature_configs, | ||
features, | ||
labels=None, | ||
is_training=False): | ||
super(MultiTowerRecall, self).__init__(model_config, feature_configs, | ||
features, labels, is_training) | ||
assert self._model_config.WhichOneof('model') == 'multi_tower_recall', ( | ||
'invalid model config: %s' % self._model_config.WhichOneof('model')) | ||
self._model_config = self._model_config.multi_tower_recall | ||
assert isinstance(self._model_config, MultiTowerRecallConfig) | ||
|
||
self.user_tower_feature, _ = self._input_layer(self._feature_dict, 'user') | ||
self.item_tower_feature, _ = self._input_layer(self._feature_dict, 'item') | ||
|
||
def build_predict_graph(self): | ||
|
||
user_tower_feature = self.user_tower_feature | ||
batch_size = tf.shape(user_tower_feature)[0] | ||
pos_item_feature = self.item_tower_feature[:batch_size] | ||
neg_item_feature = self.item_tower_feature[batch_size:] | ||
item_tower_feature = tf.concat([ | ||
pos_item_feature[:, tf.newaxis, :], | ||
tf.tile( | ||
neg_item_feature[tf.newaxis, :, :], multiples=[batch_size, 1, 1]) | ||
], | ||
axis=1) # noqa: E126 | ||
|
||
user_dnn = dnn.DNN(self._model_config.user_tower.dnn, self._l2_reg, | ||
'user_dnn', self._is_training) | ||
user_tower_emb = user_dnn(user_tower_feature) | ||
|
||
item_dnn = dnn.DNN(self._model_config.item_tower.dnn, self._l2_reg, | ||
'item_dnn', self._is_training) | ||
item_tower_emb = item_dnn(item_tower_feature) | ||
item_tower_emb = tf.reshape(item_tower_emb, tf.shape(user_tower_emb)) | ||
|
||
tower_fea_arr = [] | ||
tower_fea_arr.append(user_tower_emb) | ||
tower_fea_arr.append(item_tower_emb) | ||
|
||
all_fea = tf.concat(tower_fea_arr, axis=-1) | ||
final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg, | ||
'final_dnn', self._is_training) | ||
all_fea = final_dnn_layer(all_fea) | ||
output = tf.layers.dense(all_fea, 1, name='output') | ||
output = output[:, 0] | ||
|
||
self._add_to_prediction_dict(output) | ||
|
||
return self._prediction_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
syntax = "proto2"; | ||
package protos; | ||
|
||
import "easy_rec/python/protos/dnn.proto"; | ||
import "easy_rec/python/protos/simi.proto"; | ||
|
||
|
||
message RecallTower { | ||
required DNN dnn = 1; | ||
}; | ||
|
||
|
||
message MultiTowerRecall { | ||
required RecallTower user_tower = 1; | ||
required RecallTower item_tower = 2; | ||
required float l2_regularization = 3 [default = 1e-4]; | ||
required DNN final_dnn = 4; | ||
required bool ignore_in_batch_neg_sam = 10 [default = false]; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.