[feat]: add sequence feature negative sample process #267

Merged: 10 commits, Aug 9, 2022
Binary file added docs/images/models/uniter.png
2 changes: 1 addition & 1 deletion docs/source/models/dbmtl.md
@@ -9,7 +9,7 @@ DBMTL builds a Bayesian network among multiple objectives, explicitly modeling the

![dbmtl_mmoe.png](../../images/models/dbmtl_mmoe.png)

- In multimodal (image, video, text) recommendation scenarios, DBMTL supports using the [CMBF model](cmbf.md) as the underlying `shared layer`, so that multimodal features are fully exploited for better recommendation quality.
+ In multimodal (image, video, text) recommendation scenarios, DBMTL supports using the [CMBF model](cmbf.md) or the [UNITER model](uniter.md) as the underlying `shared layer`, so that multimodal features are fully exploited for better recommendation quality.

### Configuration

86 changes: 86 additions & 0 deletions docs/source/models/multi_tower_recall.md
@@ -0,0 +1,86 @@
# MultiTowerRecall

### Overview

A two-tower recall model built for training with negative sampling and sequence features; it consists of a user tower and an item tower.
Note: the user id and item id must be specified when using this model.

### Configuration

```protobuf
model_config: {
  model_class: "MultiTowerRecall"
  feature_groups: {
    group_name: 'user'
    feature_names: 'user_id'
    feature_names: 'cms_segid'
    feature_names: 'cms_group_id'
    feature_names: 'age_level'
    feature_names: 'pvalue_level'
    feature_names: 'shopping_level'
    feature_names: 'occupation'
    feature_names: 'new_user_class_level'
    wide_deep: DEEP
    negative_sampler: true
    sequence_features: {
      group_name: "seq_fea"
      allow_key_search: true
      need_key_feature: true
      seq_att_map: {
        key: "brand"
        key: "cate_id"
        hist_seq: "tag_brand_list"
        hist_seq: "tag_category_list"
      }
    }
  }
  feature_groups: {
    group_name: "item"
    feature_names: 'adgroup_id'
    feature_names: 'cate_id'
    feature_names: 'campaign_id'
    feature_names: 'customer'
    feature_names: 'brand'
    wide_deep: DEEP
  }
  multi_tower_recall {
    user_tower {
      id: "user_id"
      dnn {
        hidden_units: [256, 128, 64, 32]
        # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
      }
    }
    item_tower {
      id: "adgroup_id"
      dnn {
        hidden_units: [256, 128, 64, 32]
      }
    }
    final_dnn {
      hidden_units: [128, 96, 64, 32, 16]
    }
    l2_regularization: 1e-6
  }
  loss_type: CLASSIFICATION
  embedding_regularization: 5e-6
}
```

- model_class: 'MultiTowerRecall', must not be changed
- feature_groups: two feature groups are required, user and item; **the group names must not be changed**
- multi_tower_recall: parameters of the MultiTowerRecall model; user_tower and item_tower must both be configured
  - user_tower/item_tower:
    - dnn: configuration of the deep part
      - hidden_units: the number of channels (i.e. neurons) in each DNN layer
- embedding_regularization: regularization applied to the embedding part, to prevent overfitting
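
To make the negative-sampling layout concrete: with `negative_sampler: true`, the item-side input stacks the batch's positive items first, followed by a set of sampled negatives that are shared by every example in the batch. A minimal numpy sketch of that layout (illustrative sizes only; this is not EasyRec code):

```python
import numpy as np

batch_size, neg_num, emb_dim = 4, 8, 16

# as fed to the item side: batch positives stacked on top of shared negatives
item_feature = np.random.randn(batch_size + neg_num, emb_dim)

pos = item_feature[:batch_size]  # (batch_size, emb_dim), one positive per example
neg = item_feature[batch_size:]  # (neg_num, emb_dim), negatives shared by the batch
candidates = np.concatenate(
    [pos[:, None, :],                                # (batch_size, 1, emb_dim)
     np.tile(neg[None, :, :], (batch_size, 1, 1))],  # (batch_size, neg_num, emb_dim)
    axis=1)
print(candidates.shape)  # (4, 9, 16): each example scores 1 positive + 8 shared negatives
```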

Supported metrics (metric_set) include:

- auc
- mean_absolute_error
- accuracy

### Example Config

See: samples/model_config/multi_tower_recall_neg_sampler_sequence_feature.config
1 change: 1 addition & 0 deletions docs/source/models/rank.rst
@@ -24,3 +24,4 @@
   :maxdepth: 2

   cmbf
+  uniter
11 changes: 7 additions & 4 deletions easy_rec/python/layers/input_layer.py
@@ -84,6 +84,7 @@ def __call__(self, features, group_name, is_combine=True):
    assert group_name in self._feature_groups, 'invalid group_name[%s], list: %s' % (
        group_name, ','.join([x for x in self._feature_groups]))
    feature_name_to_output_tensors = {}
+   negative_sampler = self._feature_groups[group_name]._config.negative_sampler
    if group_name in self._group_name_to_seq_features:
      for seq_feature in self._group_name_to_seq_features[group_name]:
        for seq_att in seq_feature.seq_att_map:
@@ -93,10 +94,12 @@
      concat_features, group_features = self.single_call_input_layer(
          features, group_name, is_combine, feature_name_to_output_tensors)
      if group_name in self._group_name_to_seq_features:
-       seq_fea = self.sequence_feature_layer(
-           features, self._group_name_to_seq_features[group_name],
-           feature_name_to_output_tensors)
-       concat_features = tf.concat([concat_features, seq_fea], axis=1)
+       concat_features = self.sequence_feature_layer(
+           features,
+           concat_features,
+           self._group_name_to_seq_features[group_name],
+           feature_name_to_output_tensors,
+           negative_sampler=negative_sampler)
      return concat_features, group_features
    else:
      if self._variational_dropout_config is not None:
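
Why `concat_features` is now threaded through the layer: with negative sampling, the sequence attention produces one vector per (example, candidate) pair instead of one per example, so concatenating it onto the flat `concat_features` along axis 1 no longer lines up; the layer therefore tiles the non-sequence features per candidate and returns the combined tensor. A shape-only numpy sketch (hypothetical sizes, not library code):

```python
import numpy as np

batch_size, num_candidates, feat_dim = 4, 9, 32  # 9 = 1 positive + 8 negatives

concat_features = np.random.randn(batch_size, feat_dim)    # non-sequence features
seq_fea = np.random.randn(batch_size, num_candidates, 24)  # per-candidate attention output

# tile the non-sequence features so both tensors share (batch, candidate) axes
tiled = np.tile(concat_features[:, None, :], (1, num_candidates, 1))
combined = np.concatenate([tiled, seq_fea], axis=-1)
print(combined.shape)  # (4, 9, 56)
```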
98 changes: 89 additions & 9 deletions easy_rec/python/layers/sequence_feature_layer.py
@@ -29,6 +29,74 @@ def __init__(self,
    self._embedding_regularizer = embedding_regularizer
    self._is_training = is_training

  def negative_sampler_target_attention(self,
                                        dnn_config,
                                        deep_fea,
                                        concat_features,
                                        name,
                                        need_key_feature=True,
                                        allow_key_transform=False):
    cur_id, hist_id_col, seq_len, aux_hist_emb_list = deep_fea['key'], deep_fea[
        'hist_seq_emb'], deep_fea['hist_seq_len'], deep_fea[
            'aux_hist_seq_emb_list']

    seq_max_len = tf.shape(hist_id_col)[1]
    seq_emb_dim = hist_id_col.shape[2]
    cur_id_dim = tf.shape(cur_id)[-1]
    batch_size = tf.shape(hist_id_col)[0]

    # rows [0, batch_size) of cur_id are the positive keys; the remaining rows
    # are negatives shared across the batch
    pos_feature = cur_id[:batch_size]
    neg_feature = cur_id[batch_size:]
    cur_id = tf.concat([
        pos_feature[:, tf.newaxis, :],
        tf.tile(neg_feature[tf.newaxis, :, :], multiples=[batch_size, 1, 1])
    ], axis=1)  # noqa: E126
    neg_num = tf.shape(cur_id)[1]  # 1 positive + number of negatives
    hist_id_col = tf.tile(hist_id_col, multiples=[neg_num, 1, 1])
    concat_features = tf.tile(
        concat_features[:, tf.newaxis, :], multiples=[1, neg_num, 1])
    seq_len = tf.tile(seq_len, multiples=[neg_num])

    if allow_key_transform and (cur_id_dim != seq_emb_dim):
      cur_id = tf.layers.dense(
          cur_id, seq_emb_dim, name='sequence_key_transform_layer')

    cur_ids = tf.tile(cur_id, [1, 1, seq_max_len])
    cur_ids = tf.reshape(
        cur_ids,
        tf.shape(hist_id_col))  # (B * neg_num, seq_max_len, seq_emb_dim)

    din_net = tf.concat(
        [cur_ids, hist_id_col, cur_ids - hist_id_col, cur_ids * hist_id_col],
        axis=-1)  # (B * neg_num, seq_max_len, seq_emb_dim * 4)

    din_layer = dnn.DNN(dnn_config, None, name, self._is_training)
    din_net = din_layer(din_net)
    scores = tf.reshape(din_net, [-1, 1, seq_max_len])  # (B * neg_num, 1, seq_max_len)

    seq_len = tf.expand_dims(seq_len, 1)
    mask = tf.sequence_mask(seq_len)
    padding = tf.ones_like(scores) * (-2**32 + 1)
    scores = tf.where(mask, scores, padding)  # (B * neg_num, 1, seq_max_len)

    # Scale
    scores = tf.nn.softmax(scores)  # (B * neg_num, 1, seq_max_len)
    hist_din_emb = tf.matmul(scores, hist_id_col)  # (B * neg_num, 1, seq_emb_dim)
    hist_din_emb = tf.reshape(
        hist_din_emb, [batch_size, neg_num, seq_emb_dim])  # (B, neg_num, seq_emb_dim)
    if len(aux_hist_emb_list) > 0:
      all_hist_dim_emb = [hist_din_emb]
      for hist_col in aux_hist_emb_list:
        # auxiliary histories must be tiled like the main history so they line
        # up with the (B * neg_num, 1, seq_max_len) attention scores
        hist_col = tf.tile(hist_col, multiples=[neg_num, 1, 1])
        cur_aux_hist = tf.matmul(scores, hist_col)
        outputs = tf.reshape(cur_aux_hist, [batch_size, neg_num, -1])
        all_hist_dim_emb.append(outputs)
      hist_din_emb = tf.concat(all_hist_dim_emb, axis=-1)
    if not need_key_feature:
      # keep the tiled concat_features paired with the sequence output
      return hist_din_emb, concat_features
    din_output = tf.concat([hist_din_emb, cur_id], axis=2)
    return din_output, concat_features

  def target_attention(self,
                       dnn_config,
                       deep_fea,
@@ -85,8 +153,10 @@ def target_attention(self,

  def __call__(self,
               features,
+              concat_features,
               all_seq_att_map_config,
-              feature_name_to_output_tensors=None):
+              feature_name_to_output_tensors=None,
+              negative_sampler=False):
    logging.info('use sequence feature layer.')
    all_seq_fea = []
    # process all sequence features
@@ -114,13 +184,23 @@
        seq_dnn_config = DNN()
        seq_dnn_config.hidden_units.extend([128, 64, 32, 1])
      cur_target_attention_name = 'seq_dnn' + group_name
-     seq_fea = self.target_attention(
-         seq_dnn_config,
-         seq_features,
-         name=cur_target_attention_name,
-         need_key_feature=need_key_feature,
-         allow_key_transform=allow_key_transform)
+     if negative_sampler:
+       seq_fea, concat_features = self.negative_sampler_target_attention(
+           seq_dnn_config,
+           seq_features,
+           concat_features,
+           name=cur_target_attention_name,
+           need_key_feature=need_key_feature,
+           allow_key_transform=allow_key_transform)
+     else:
+       seq_fea = self.target_attention(
+           seq_dnn_config,
+           seq_features,
+           name=cur_target_attention_name,
+           need_key_feature=need_key_feature,
+           allow_key_transform=allow_key_transform)
      all_seq_fea.append(seq_fea)
    # concat all seq_fea
-   all_seq_fea = tf.concat(all_seq_fea, axis=1)
-   return all_seq_fea
+   all_seq_fea = tf.concat(all_seq_fea, axis=-1)
+   concat_features = tf.concat([concat_features, all_seq_fea], axis=-1)
+   return concat_features
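
For reference, the attention core that `negative_sampler_target_attention` shares with `target_attention`: score each history step against the target key, mask the padding, softmax, then attention-pool the history. A dot-product stand-in in numpy (the real layer scores with a small DNN over `[key, hist, key - hist, key * hist]`; sizes are invented):

```python
import numpy as np

B, T, D = 4, 6, 16                # batch (or batch * neg_num), max seq len, emb dim
key = np.random.randn(B, 1, D)    # target item embedding
hist = np.random.randn(B, T, D)   # embedded behaviour sequence
seq_len = np.array([6, 3, 5, 1])  # valid history length per row

scores = (key * hist).sum(-1)                    # (B, T) attention logits
mask = np.arange(T)[None, :] < seq_len[:, None]  # False on padding positions
scores = np.where(mask, scores, -2.0**32 + 1)    # same padding value as the layer
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)        # masked softmax
pooled = (weights[:, :, None] * hist).sum(1)     # (B, D) attention-pooled history
print(pooled.shape)
```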
68 changes: 68 additions & 0 deletions easy_rec/python/model/multi_tower_recall.py
@@ -0,0 +1,68 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import tensorflow as tf

from easy_rec.python.layers import dnn
from easy_rec.python.model.rank_model import RankModel

from easy_rec.python.protos.multi_tower_recall_pb2 import MultiTowerRecall as MultiTowerRecallConfig  # NOQA

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class MultiTowerRecall(RankModel):

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(MultiTowerRecall, self).__init__(model_config, feature_configs,
                                           features, labels, is_training)
    assert self._model_config.WhichOneof('model') == 'multi_tower_recall', (
        'invalid model config: %s' % self._model_config.WhichOneof('model'))
    self._model_config = self._model_config.multi_tower_recall
    assert isinstance(self._model_config, MultiTowerRecallConfig)

    self.user_tower_feature, _ = self._input_layer(self._feature_dict, 'user')
    self.item_tower_feature, _ = self._input_layer(self._feature_dict, 'item')

  def build_predict_graph(self):

    user_tower_feature = self.user_tower_feature
    batch_size = tf.shape(user_tower_feature)[0]
    pos_item_feature = self.item_tower_feature[:batch_size]
    neg_item_feature = self.item_tower_feature[batch_size:]
    item_tower_feature = tf.concat([
        pos_item_feature[:, tf.newaxis, :],
        tf.tile(
            neg_item_feature[tf.newaxis, :, :], multiples=[batch_size, 1, 1])
    ], axis=1)  # noqa: E126

    user_dnn = dnn.DNN(self._model_config.user_tower.dnn, self._l2_reg,
                       'user_dnn', self._is_training)
    user_tower_emb = user_dnn(user_tower_feature)

    item_dnn = dnn.DNN(self._model_config.item_tower.dnn, self._l2_reg,
                       'item_dnn', self._is_training)
    item_tower_emb = item_dnn(item_tower_feature)
    item_tower_emb = tf.reshape(item_tower_emb, tf.shape(user_tower_emb))

    tower_fea_arr = []
    tower_fea_arr.append(user_tower_emb)
    tower_fea_arr.append(item_tower_emb)

    all_fea = tf.concat(tower_fea_arr, axis=-1)
    final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
                              'final_dnn', self._is_training)
    all_fea = final_dnn_layer(all_fea)
    output = tf.layers.dense(all_fea, 1, name='output')
    output = output[:, 0]

    self._add_to_prediction_dict(output)

    return self._prediction_dict
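
Shape-wise, `build_predict_graph` pairs every example with its positive plus the shared negatives and scores each pair through `final_dnn`. A numpy sketch of that flow (sizes invented; it assumes the user-side features were already tiled per candidate by the sequence layer, as happens when the user group sets `negative_sampler: true`):

```python
import numpy as np

batch_size, num_candidates, user_dim, item_dim = 4, 9, 48, 40

user_fea = np.random.randn(batch_size, num_candidates, user_dim)  # user tower input
item_fea = np.random.randn(batch_size, num_candidates, item_dim)  # positives + tiled negatives

# stand-ins for user_dnn / item_dnn: project both sides to 32 dims
user_emb = user_fea @ np.random.randn(user_dim, 32)
item_emb = item_fea @ np.random.randn(item_dim, 32)

all_fea = np.concatenate([user_emb, item_emb], axis=-1)  # (4, 9, 64), input to final_dnn
logits = all_fea @ np.random.randn(64, 1)                # stand-in for final_dnn + dense(1)
print(logits.shape)  # (4, 9, 1): one logit per (example, candidate) pair
```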
2 changes: 2 additions & 0 deletions easy_rec/python/protos/easy_rec_model.proto
@@ -23,6 +23,7 @@ import "easy_rec/python/protos/mind.proto";
import "easy_rec/python/protos/loss.proto";
import "easy_rec/python/protos/rocket_launching.proto";
import "easy_rec/python/protos/variational_dropout.proto";
+import "easy_rec/python/protos/multi_tower_recall.proto";
// for input performance test
message DummyModel {

@@ -65,6 +66,7 @@ message EasyRecModel {
    CMBF cmbf = 109;
    Uniter uniter = 110;

+   MultiTowerRecall multi_tower_recall = 200;
    DSSM dssm = 201;
    MIND mind = 202;
    DropoutNet dropoutnet = 203;
1 change: 1 addition & 0 deletions easy_rec/python/protos/feature_config.proto
@@ -135,6 +135,7 @@ message FeatureGroupConfig {

  optional WideOrDeep wide_deep = 3 [default = DEEP];
  repeated SeqAttGroupConfig sequence_features = 4;
+ optional bool negative_sampler = 5 [default = false];
}

message SeqAttMap {
19 changes: 19 additions & 0 deletions easy_rec/python/protos/multi_tower_recall.proto
@@ -0,0 +1,19 @@
syntax = "proto2";
package protos;

import "easy_rec/python/protos/dnn.proto";
import "easy_rec/python/protos/simi.proto";


message RecallTower {
  required DNN dnn = 1;
};


message MultiTowerRecall {
  required RecallTower user_tower = 1;
  required RecallTower item_tower = 2;
  required float l2_regularization = 3 [default = 1e-4];
  required DNN final_dnn = 4;
  required bool ignore_in_batch_neg_sam = 10 [default = false];
}
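
Once the protos are compiled, the new message can be built and inspected from Python; a hypothetical usage sketch (field names taken from the definition above):

```python
from google.protobuf import text_format

from easy_rec.python.protos.multi_tower_recall_pb2 import MultiTowerRecall

cfg = MultiTowerRecall()
cfg.user_tower.dnn.hidden_units.extend([256, 128, 64, 32])
cfg.item_tower.dnn.hidden_units.extend([256, 128, 64, 32])
cfg.final_dnn.hidden_units.extend([128, 96, 64, 32, 16])
cfg.l2_regularization = 1e-6
print(text_format.MessageToString(cfg))
```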
9 changes: 8 additions & 1 deletion easy_rec/python/test/train_eval_test.py
@@ -5,11 +5,11 @@
import logging
import os
import unittest
+from distutils.version import LooseVersion

import numpy as np
import six
import tensorflow as tf
-from distutils.version import LooseVersion
from tensorflow.python.platform import gfile

from easy_rec.python.main import predict
@@ -905,6 +905,13 @@ def test_deepfm_on_sequence_feature_aux_hist_seq(self):
        self._test_dir)
    self.assertTrue(self._success)

  @unittest.skipIf(gl is None, 'graphlearn is not installed')
  def test_multi_tower_recall_neg_sampler_sequence_feature(self):
    self._success = test_utils.test_single_train_eval(
        'samples/model_config/multi_tower_recall_neg_sampler_sequence_feature.config',
        self._test_dir)
    self.assertTrue(self._success)


if __name__ == '__main__':
  tf.test.main()
2 changes: 1 addition & 1 deletion easy_rec/python/utils/estimator_utils.py
@@ -9,11 +9,11 @@
import os
import re
import time
+from distutils.version import LooseVersion

import numpy as np
import six
import tensorflow as tf
-from distutils.version import LooseVersion
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.framework import meta_graph
from tensorflow.python.training.summary_io import SummaryWriterCache