Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat]:support only sequence feature && fix neg sampler bug for sequence feature #264

Merged
merged 3 commits into from
Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/feature/feature.rst
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ Sequence类特征格式一般为“XX\|XX\|XX”,如用户行为序列特征
sequence_features: {
group_name: "seq_fea"
allow_key_search: true
need_key_feature: true
seq_att_map: {
key: "brand"
key: "cate_id"
Expand All @@ -281,6 +282,8 @@ Sequence类特征格式一般为“XX\|XX\|XX”,如用户行为序列特征

- sequence_features: 序列特征组的名称
- allow_key_search: 当 key 对应的特征没有在 feature_groups 里面时,需要设置为 true, 将会复用对应特征的 embedding.
- need_key_feature: 默认为 true, 指过完 target attention 之后的特征会和 key 对应的特征 concat 之后返回。
设置为 false 时,将会只返回过完 target attention 之后的特征。
- seq_att_map: 具体细节可以参考排序里的 DIN 模型。
- NOTE:SequenceFeature一般放在 user 组里面。

Expand Down
11 changes: 9 additions & 2 deletions easy_rec/python/layers/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,14 @@ def __init__(self,
def has_group(self, group_name):
return group_name in self._feature_groups

def target_attention(self, dnn_config, deep_fea, name):
def target_attention(self, dnn_config, deep_fea, name, need_key_feature=True):
cur_id, hist_id_col, seq_len = deep_fea['key'], deep_fea[
'hist_seq_emb'], deep_fea['hist_seq_len']

seq_max_len = tf.shape(hist_id_col)[1]
emb_dim = hist_id_col.shape[2]

cur_id = cur_id[:tf.shape(hist_id_col)[0], ...] # for negative sampler
cur_ids = tf.tile(cur_id, [1, seq_max_len])
cur_ids = tf.reshape(cur_ids,
tf.shape(hist_id_col)) # (B, seq_max_len, emb_dim)
Expand All @@ -96,6 +97,8 @@ def target_attention(self, dnn_config, deep_fea, name):
scores = tf.nn.softmax(scores) # (B, 1, seq_max_len)
hist_din_emb = tf.matmul(scores, hist_id_col) # [B, 1, emb_dim]
hist_din_emb = tf.reshape(hist_din_emb, [-1, emb_dim]) # [B, emb_dim]
if not need_key_feature:
return hist_din_emb
din_output = tf.concat([hist_din_emb, cur_id], axis=1)
return din_output

Expand All @@ -108,6 +111,7 @@ def call_seq_input_layer(self,
for seq_att_map_config in all_seq_att_map_config:
group_name = seq_att_map_config.group_name
allow_key_search = seq_att_map_config.allow_key_search
need_key_feature = seq_att_map_config.need_key_feature
seq_features = self._seq_input_layer(features, group_name,
feature_name_to_output_tensors,
allow_key_search)
Expand All @@ -128,7 +132,10 @@ def call_seq_input_layer(self,
seq_dnn_config.hidden_units.extend([128, 64, 32, 1])
cur_target_attention_name = 'seq_dnn' + group_name
seq_fea = self.target_attention(
seq_dnn_config, seq_features, name=cur_target_attention_name)
seq_dnn_config,
seq_features,
name=cur_target_attention_name,
need_key_feature=need_key_feature)
all_seq_fea.append(seq_fea)
# concat all seq_fea
all_seq_fea = tf.concat(all_seq_fea, axis=1)
Expand Down
1 change: 1 addition & 0 deletions easy_rec/python/protos/feature_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,5 @@ message SeqAttGroupConfig {
optional bool tf_summary = 3 [default = false];
optional DNN seq_dnn = 4;
optional bool allow_key_search = 5 [default = false];
optional bool need_key_feature = 6 [default = true];
}
20 changes: 20 additions & 0 deletions easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,26 @@ def test_distribute_eval_esmm(self):
cur_eval_path, self._test_dir)
self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_neg_sampler_sequence_feature(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dssm_neg_sampler_sequence_feature.config',
self._test_dir)
self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_neg_sampler_need_key_feature(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dssm_neg_sampler_need_key_feature.config',
self._test_dir)
self.assertTrue(self._success)

def test_dbmtl_on_multi_numeric_boundary_need_key_feature(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dbmtl_on_multi_numeric_boundary_need_key_feature_taobao.config',
self._test_dir)
self.assertTrue(self._success)


if __name__ == '__main__':
tf.test.main()
Loading