
[feat] Support gnn on datascience #229

Merged
15 commits merged on Jul 18, 2022
1 change: 1 addition & 0 deletions .git_bin_path
@@ -31,6 +31,7 @@
{"leaf_name": "data/test/inference/tb_multitower_rtp_export/variables", "leaf_file": ["data/test/inference/tb_multitower_rtp_export/variables/variables.data-00000-of-00001", "data/test/inference/tb_multitower_rtp_export/variables/variables.index"]}
{"leaf_name": "data/test/latest_ckpt_test", "leaf_file": ["data/test/latest_ckpt_test/model.ckpt-500.data-00000-of-00001", "data/test/latest_ckpt_test/model.ckpt-500.index", "data/test/latest_ckpt_test/model.ckpt-500.meta"]}
{"leaf_name": "data/test/movielens_1m", "leaf_file": ["data/test/movielens_1m/ml_test_data", "data/test/movielens_1m/ml_train_data"]}
{"leaf_name": "data/test/mt_ckpt", "leaf_file": ["data/test/mt_ckpt/model.ckpt-100.data-00000-of-00001", "data/test/mt_ckpt/model.ckpt-100.index", "data/test/mt_ckpt/model.ckpt-100.meta"]}
{"leaf_name": "data/test/rtp", "leaf_file": ["data/test/rtp/taobao_fg_pred.out", "data/test/rtp/taobao_test_bucketize_feature.txt", "data/test/rtp/taobao_test_feature.txt", "data/test/rtp/taobao_test_input.txt", "data/test/rtp/taobao_train_bucketize_feature.txt", "data/test/rtp/taobao_train_feature.txt", "data/test/rtp/taobao_train_input.txt", "data/test/rtp/taobao_valid.csv", "data/test/rtp/taobao_valid_feature.txt"]}
{"leaf_name": "data/test/tb_data", "leaf_file": ["data/test/tb_data/taobao_ad_feature_gl", "data/test/tb_data/taobao_clk_edge_gl", "data/test/tb_data/taobao_multi_seq_test_data", "data/test/tb_data/taobao_multi_seq_train_data", "data/test/tb_data/taobao_noclk_edge_gl", "data/test/tb_data/taobao_test_data", "data/test/tb_data/taobao_test_data_compress.gz", "data/test/tb_data/taobao_test_data_for_expr", "data/test/tb_data/taobao_test_data_kd", "data/test/tb_data/taobao_train_data", "data/test/tb_data/taobao_train_data_for_expr", "data/test/tb_data/taobao_train_data_kd", "data/test/tb_data/taobao_user_profile_gl"]}
{"leaf_name": "data/test/tb_data_with_time", "leaf_file": ["data/test/tb_data_with_time/taobao_test_data_with_time", "data/test/tb_data_with_time/taobao_train_data_with_time"]}
1 change: 1 addition & 0 deletions .git_bin_url
@@ -31,6 +31,7 @@
{"leaf_path": "data/test/inference/tb_multitower_rtp_export/variables", "sig": "efe52ef308fd6452f3b67fd04cdd22bd", "remote_path": "data/git_oss_sample_data/data_test_inference_tb_multitower_rtp_export_variables_efe52ef308fd6452f3b67fd04cdd22bd"}
{"leaf_path": "data/test/latest_ckpt_test", "sig": "d41d8cd98f00b204e9800998ecf8427e", "remote_path": "data/git_oss_sample_data/data_test_latest_ckpt_test_d41d8cd98f00b204e9800998ecf8427e"}
{"leaf_path": "data/test/movielens_1m", "sig": "99badbeec64f2fcabe0dfa1d2bfd8fb5", "remote_path": "data/git_oss_sample_data/data_test_movielens_1m_99badbeec64f2fcabe0dfa1d2bfd8fb5"}
{"leaf_path": "data/test/mt_ckpt", "sig": "803499f48e2df5e51ce5606e9649c6d4", "remote_path": "data/git_oss_sample_data/data_test_mt_ckpt_803499f48e2df5e51ce5606e9649c6d4"}
{"leaf_path": "data/test/rtp", "sig": "76cda60582617ddbb7cd5a49eb68a4b9", "remote_path": "data/git_oss_sample_data/data_test_rtp_76cda60582617ddbb7cd5a49eb68a4b9"}
{"leaf_path": "data/test/tb_data", "sig": "126c375d6aa666633fb3084aa27ff9f7", "remote_path": "data/git_oss_sample_data/data_test_tb_data_126c375d6aa666633fb3084aa27ff9f7"}
{"leaf_path": "data/test/tb_data_with_time", "sig": "1a7648f4ae55faf37855762bccbb70cc", "remote_path": "data/git_oss_sample_data/data_test_tb_data_with_time_1a7648f4ae55faf37855762bccbb70cc"}
6 changes: 3 additions & 3 deletions docs/source/models/cmbf.md
@@ -9,9 +9,9 @@ Cross-Modal-Based Fusion Recommendation Algorithm (CMBF) is an algorithm that can capture
CMBF consists of four main modules (as shown in the figure above):

1. Preprocessing module: extracts image and text features
1. Single-modal learning module: learns semantic features of images and text based on Transformers
1. Cross-modal fusion module: learns cross interactions between the two modalities
1. Output module: obtains high-order features and predicts the result
2. Single-modal learning module: learns semantic features of images and text based on Transformers
3. Cross-modal fusion module: learns cross interactions between the two modalities
4. Output module: obtains high-order features and predicts the result

The model supports four types of feature groups (`feature group`), as described below.
Not all four input feature types are required: the model can be trained as long as at least one type is present. Depending on which input feature types are provided, parts of the network may be `short-circuited` (skipped).
56 changes: 48 additions & 8 deletions easy_rec/python/core/sampler.py
@@ -14,6 +14,7 @@
import tensorflow as tf

from easy_rec.python.protos.dataset_pb2 import DatasetConfig
from easy_rec.python.utils import ds_util
from easy_rec.python.utils.tf_utils import get_tf_type

try:
@@ -63,6 +64,7 @@ def __init__(self, fields, num_sample, num_eval_sample=None):
self._num_eval_sample = num_eval_sample if num_eval_sample is not None else num_sample
self._build_field_types(fields)
self._log_first_n = 5
self._is_on_ds = ds_util.is_on_ds()

def set_eval_num_sample(self):
print('set_eval_num_sample: %d %d' %
@@ -75,9 +77,20 @@ def _init_graph(self):
if 'ps' in tf_config['cluster']:
# ps mode
tf_config = json.loads(os.environ['TF_CONFIG'])
ps_count = len(tf_config['cluster']['ps'])
task_count = len(tf_config['cluster']['worker']) + 2
cluster = {'server_count': ps_count, 'client_count': task_count}
if self._is_on_ds:
gl.set_tracker_mode(0)
server_hosts = [
host.split(':')[0] + ':888' + str(i)
for i, host in enumerate(tf_config['cluster']['ps'])
]
cluster = {
'server': ','.join(server_hosts),
'client_count': task_count
}
else:
ps_count = len(tf_config['cluster']['ps'])
cluster = {'server_count': ps_count, 'client_count': task_count}
if tf_config['task']['type'] in ['chief', 'master']:
self._g.init(cluster=cluster, job_name='client', task_index=0)
elif tf_config['task']['type'] == 'worker':
@@ -101,11 +114,35 @@ def _init_graph(self):
else:
# worker mode
task_count = len(tf_config['cluster']['worker']) + 1
if tf_config['task']['type'] in ['chief', 'master']:
self._g.init(task_index=0, task_count=task_count)
elif tf_config['task']['type'] == 'worker':
self._g.init(
task_index=tf_config['task']['index'] + 1, task_count=task_count)
if not self._is_on_ds:
if tf_config['task']['type'] in ['chief', 'master']:
self._g.init(task_index=0, task_count=task_count)
elif tf_config['task']['type'] == 'worker':
self._g.init(
task_index=tf_config['task']['index'] + 1,
task_count=task_count)
else:
gl.set_tracker_mode(0)
if tf_config['cluster'].get('chief', ''):
chief_host = tf_config['cluster']['chief'][0].split(
':')[0] + ':8880'
else:
chief_host = tf_config['cluster']['master'][0].split(
':')[0] + ':8880'
worker_hosts = [chief_host] + \
[host.split(':')[0] + ':888' + str(i) for i, host in enumerate(tf_config['cluster']['worker'])]

if tf_config['task']['type'] in ['chief', 'master']:
self._g.init(
task_index=0,
task_count=task_count,
hosts=','.join(worker_hosts))
elif tf_config['task']['type'] == 'worker':
self._g.init(
task_index=tf_config['task']['index'] + 1,
task_count=task_count,
hosts=','.join(worker_hosts))

# TODO(hongsheng.jhs): check cluster has evaluator or not?
else:
# local mode
@@ -276,7 +313,6 @@ def __init__(self,
self._item_ids = []
self._cols = [[] for x in fields]

# try load from odps table
if six.PY2 and isinstance(attr_delimiter, type(u'')):
attr_delimiter = attr_delimiter.encode('utf-8')
if data_path.startswith('odps://'):
@@ -674,9 +710,13 @@ def build(data_config):
sampler_type = data_config.WhichOneof('sampler')
print('sampler_type = %s' % sampler_type)
sampler_config = getattr(data_config, sampler_type)
if ds_util.is_on_ds():
gl.set_field_delimiter(sampler_config.field_delimiter)

if sampler_type == 'negative_sampler':
input_fields = {f.input_name: f for f in data_config.input_fields}
attr_fields = [input_fields[name] for name in sampler_config.attr_fields]

return NegativeSampler.instance(
data_path=sampler_config.input_path,
fields=attr_fields,
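As a reading aid, here is a minimal, self-contained sketch of the host-naming convention the worker-mode branch above builds for GraphLearn: the chief is pinned to port 8880 and worker `i` gets port `888<i>`. The `TF_CONFIG` content below is illustrative, not from the PR:

```python
import json
import os

# Illustrative TF_CONFIG; on DataScience the real value comes from the environment.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'chief': ['10.0.0.1:2222'],
        'worker': ['10.0.0.2:2222', '10.0.0.3:2222'],
    },
    'task': {'type': 'worker', 'index': 0},
})

tf_config = json.loads(os.environ['TF_CONFIG'])
# Chief is pinned to port 8880; worker i gets port 888<i>.
chief_host = tf_config['cluster']['chief'][0].split(':')[0] + ':8880'
worker_hosts = [chief_host] + [
    host.split(':')[0] + ':888' + str(i)
    for i, host in enumerate(tf_config['cluster']['worker'])
]
print(','.join(worker_hosts))  # -> 10.0.0.1:8880,10.0.0.2:8880,10.0.0.3:8881
```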
7 changes: 5 additions & 2 deletions easy_rec/python/eval.py
@@ -10,6 +10,7 @@

from easy_rec.python.main import distribute_evaluate
from easy_rec.python.main import evaluate
from easy_rec.python.utils import ds_util

from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_distribute_eval_worker_num_on_ds # NOQA
if tf.__version__ >= '2.0':
@@ -41,8 +42,10 @@ def main(argv):
if FLAGS.odps_config:
os.environ['ODPS_CONFIG_FILE_PATH'] = FLAGS.odps_config

if FLAGS.is_on_ds and FLAGS.distribute_eval:
set_tf_config_and_get_distribute_eval_worker_num_on_ds()
if FLAGS.is_on_ds:
ds_util.set_on_ds()
if FLAGS.distribute_eval:
set_tf_config_and_get_distribute_eval_worker_num_on_ds()

assert FLAGS.model_dir or FLAGS.pipeline_config_path, 'At least one of model_dir and pipeline_config_path must be provided.'
if FLAGS.model_dir:
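The body of `ds_util` is not part of this diff. Judging from the call sites (`set_on_ds()` in entry scripts, `is_on_ds()` inside the sampler), a plausible minimal implementation is a process-wide flag; this is an assumption, not the PR's actual code:

```python
# Hypothetical sketch of easy_rec/python/utils/ds_util.py; the real module
# is not shown in this diff. All it needs is a flag that entry points set
# and library code (e.g. the sampler) queries.
_IS_ON_DS = False


def set_on_ds():
  """Mark the current process as running on DataScience."""
  global _IS_ON_DS
  _IS_ON_DS = True


def is_on_ds():
  """Return True if the process was marked as running on DataScience."""
  return _IS_ON_DS
```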
6 changes: 3 additions & 3 deletions easy_rec/python/inference/predictor.py
@@ -26,6 +26,7 @@
from easy_rec.python.utils.check_utils import check_split
from easy_rec.python.utils.config_util import get_configs_from_pipeline_file
from easy_rec.python.utils.config_util import get_input_name_from_fg_json
from easy_rec.python.utils.config_util import search_fg_json
from easy_rec.python.utils.hive_utils import HiveUtils
from easy_rec.python.utils.input_utils import get_type_defaults
from easy_rec.python.utils.load_class import get_register_class_meta
@@ -404,9 +405,8 @@ def _get_fg_json(self, fg_json_path, model_path):
with tf.gfile.GFile(fg_json_path, 'r') as fin:
fg_json = json.loads(fin.read())
else:
fg_json_path = os.path.join(model_path, 'assets/fg.json')
if gfile.Exists(fg_json_path):
logging.info('load fg_json_path: ', fg_json_path)
fg_json_path = search_fg_json(model_path)
if fg_json_path:
with tf.gfile.GFile(fg_json_path, 'r') as fin:
fg_json = json.loads(fin.read())
else:
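`search_fg_json` replaces the hard-coded `assets/fg.json` lookup, but its body is not shown in this diff. A hypothetical sketch of such a helper, with assumed candidate locations:

```python
import logging
import os

import tensorflow as tf


def search_fg_json(model_path):
  """Hypothetical sketch: locate fg.json under an exported model.

  The candidate paths are assumptions; returns None when nothing exists.
  """
  candidates = [
      os.path.join(model_path, 'assets/fg.json'),
      os.path.join(model_path, 'fg.json'),
  ]
  for fg_json_path in candidates:
    if tf.gfile.Exists(fg_json_path):
      logging.info('load fg_json_path: %s' % fg_json_path)
      return fg_json_path
  return None
```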
11 changes: 7 additions & 4 deletions easy_rec/python/input/hive_rtp_input.py
@@ -48,24 +48,27 @@ def __init__(self,
def _parse_csv(self, line):
record_defaults = []
for tid, field_name in enumerate(self._input_table_col_names):
if field_name in self._selected_cols:
if field_name in self._selected_cols[:-1]:
idx = self._input_fields.index(field_name)
record_defaults.append(
self.get_type_defaults(self._input_field_types[tid],
self._input_field_defaults[tid]))
self.get_type_defaults(self._input_field_types[idx],
self._input_field_defaults[idx]))
else:
record_defaults.append('')

print('record_defaults: ', record_defaults)
tmp_fields = tf.decode_csv(
line,
field_delim=self._rtp_separator,
record_defaults=record_defaults,
name='decode_csv')
print('tmp_fields: ', tmp_fields)

fields = []
if self._selected_cols:
for idx, field_name in enumerate(self._input_table_col_names):
if field_name in self._selected_cols:
fields.append(tmp_fields[idx])
print('fields: ', fields)
labels = fields[:-1]

# only for features, labels and sample_weight excluded
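The `[:-1]` change means only the label columns (all selected columns except the last, which carries the packed feature string) get typed defaults; every other column, including completely unselected ones, is parsed back as a string, and unselected columns are dropped when `fields` is assembled. A toy TF 1.x illustration of how `record_defaults` drives `tf.decode_csv` (column layout and values illustrative):

```python
import tensorflow as tf

# record_defaults fixes both the parsed dtype and the fallback value of each
# column. Columns defaulted with '' come back as strings.
line = tf.constant('0;1;i_am_the_packed_feature_string;unused_col')
record_defaults = [
    [0],   # selected label column: typed default (int)
    [0],   # selected label column: typed default (int)
    [''],  # last selected column: packed feature string
    [''],  # unselected column: parsed as string, dropped afterwards
]
fields = tf.decode_csv(
    line, record_defaults=record_defaults, field_delim=';')
with tf.Session() as sess:
  print(sess.run(fields))
  # [0, 1, b'i_am_the_packed_feature_string', b'unused_col']
```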
15 changes: 15 additions & 0 deletions easy_rec/python/protos/dataset.proto
@@ -16,6 +16,9 @@ message NegativeSampler {
optional string attr_delimiter = 5 [default=":"];

optional uint32 num_eval_sample = 6 [default=0];

// only works on DataScience/Local
optional string field_delimiter = 7 [default="\001"];
}

message NegativeSamplerInMemory {
@@ -32,6 +35,9 @@ message NegativeSamplerInMemory {
optional string attr_delimiter = 5 [default=":"];

optional uint32 num_eval_sample = 6 [default=0];

// only works on DataScience/Local
optional string field_delimiter = 7 [default="\001"];
}

// Weighted Random Sampling ItemID not with Edge
@@ -57,6 +63,9 @@ message NegativeSamplerV2 {
optional string attr_delimiter = 8 [default=":"];

optional uint32 num_eval_sample = 9 [default=0];

// only works on DataScience/Local
optional string field_delimiter = 10 [default="\001"];
}

// Weighted Random Sampling ItemID not in Batch and Sampling Hard Edge
@@ -84,6 +93,9 @@ message HardNegativeSampler {
optional string attr_delimiter = 9 [default=":"];

optional uint32 num_eval_sample = 10 [default=0];

// only works on DataScience/Local
optional string field_delimiter = 11 [default="\001"];
}

// Weighted Random Sampling ItemID not with Edge and Sampling Hard Edge
@@ -114,6 +126,9 @@ message HardNegativeSamplerV2 {
optional string attr_delimiter = 10 [default=":"];

optional uint32 num_eval_sample = 11 [default=0];

// only works on DataScience/Local
optional string field_delimiter = 12 [default="\001"];
}

message DatasetConfig {
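For context, a hedged example of how the new `field_delimiter` field could appear in a pipeline config's sampler block; the attribute names and values are illustrative, and per the comments above the field only takes effect on DataScience or in local mode:

```protobuf
data_config {
  negative_sampler {
    input_path: "data/test/tb_data/taobao_ad_feature_gl"
    num_sample: 1024
    attr_fields: "adgroup_id"
    attr_fields: "cate_id"
    item_id_field: "adgroup_id"
    attr_delimiter: ":"
    num_eval_sample: 2048
    # only works on DataScience/Local
    field_delimiter: "\001"
  }
}
```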
18 changes: 18 additions & 0 deletions easy_rec/python/test/train_eval_test.py
@@ -184,6 +184,24 @@ def _post_check_func(pipeline_config):
post_check_func=_post_check_func)
self.assertTrue(self._success)

def test_fine_tune_latest_ckpt_path(self):

def _post_check_func(pipeline_config):
logging.info('model_dir: %s' % pipeline_config.model_dir)
pipeline_config = config_util.get_configs_from_pipeline_file(
os.path.join(pipeline_config.model_dir, 'pipeline.config'), False)
logging.info('fine_tune_checkpoint: %s' %
pipeline_config.train_config.fine_tune_checkpoint)
return pipeline_config.train_config.fine_tune_checkpoint == \
'data/test/mt_ckpt/model.ckpt-100'

self._success = test_utils.test_single_train_eval(
'samples/model_config/multi_tower_on_taobao.config',
self._test_dir,
fine_tune_checkpoint='data/test/mt_ckpt',
post_check_func=_post_check_func)
self.assertTrue(self._success)

def test_fine_tune_ckpt(self):

def _post_check_func(pipeline_config):
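The new test passes a checkpoint directory rather than a full prefix as `fine_tune_checkpoint` and expects it to resolve to `model.ckpt-100`. The resolution helper itself is not shown in this diff; a hypothetical sketch of directory-to-prefix resolution by globbing, which works even without a `checkpoint` state file in the directory:

```python
import os
import re

import tensorflow as tf


def latest_checkpoint_in_dir(model_dir):
  """Hypothetical sketch: resolve a directory to its newest ckpt prefix.

  Globs model.ckpt-<step>.index files, so it does not rely on a
  `checkpoint` state file being present (data/test/mt_ckpt has none).
  """
  steps = []
  for index_path in tf.gfile.Glob(
      os.path.join(model_dir, 'model.ckpt-*.index')):
    match = re.search(r'model\.ckpt-(\d+)\.index$', index_path)
    if match:
      steps.append(int(match.group(1)))
  if not steps:
    return None
  return os.path.join(model_dir, 'model.ckpt-%d' % max(steps))


print(latest_checkpoint_in_dir('data/test/mt_ckpt'))
# -> data/test/mt_ckpt/model.ckpt-100 with the bundled test data
```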
48 changes: 12 additions & 36 deletions easy_rec/python/tools/split_model_pai.py
@@ -1,5 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Split DSSM and Mind saved_model into user part and item part.
import copy
import logging
import os
@@ -18,6 +17,8 @@
tf.app.flags.DEFINE_string('model_dir', '', '')
tf.app.flags.DEFINE_string('user_model_dir', '', '')
tf.app.flags.DEFINE_string('item_model_dir', '', '')
tf.app.flags.DEFINE_string('user_fg_json_path', '', '')
tf.app.flags.DEFINE_string('item_fg_json_path', '', '')

logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
@@ -93,35 +94,6 @@ def extract_sub_graph(graph_def, dest_nodes, variable_protos):
if n in edges:
nodes_to_keep.add(n)
next_to_visit += edges[n]

init_all_tables = []
if 'init_all_tables' in edges:
for init in edges['init_all_tables']:
init = init.strip()
sufix = '/table_init'
table_name = init[:-len(sufix)]
if table_name in nodes_to_keep:
init_all_tables.append(init)

next_to_visit = list(init_all_tables)
while next_to_visit:
n = next_to_visit[0]

if n in variable_protos:
proto = variable_protos[n]
next_to_visit.append(_node_name(proto.initial_value_name))
next_to_visit.append(_node_name(proto.initializer_name))
next_to_visit.append(_node_name(proto.snapshot_name))
variables_to_keep.add(proto.variable_name)

del next_to_visit[0]
if n in nodes_to_keep:
continue
# make sure n is in edges
if n in edges:
nodes_to_keep.add(n)
next_to_visit += edges[n]

nodes_to_keep_list = sorted(list(nodes_to_keep), key=lambda n: node_seq[n])

out = graph_pb2.GraphDef()
@@ -130,7 +102,7 @@ def extract_sub_graph(graph_def, dest_nodes, variable_protos):
out.library.CopyFrom(graph_def.library)
out.versions.CopyFrom(graph_def.versions)

return out, variables_to_keep, init_all_tables
return out, variables_to_keep


def load_meta_graph_def(model_dir):
@@ -217,7 +189,7 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
_node_name(output_tensor_names[x]) for x in output_tensor_names.keys()
]

inference_graph, variables_to_keep, init_op_names = extract_sub_graph(
inference_graph, variables_to_keep = extract_sub_graph(
meta_graph_def.graph_def, output_node_names, variable_protos)

tf.reset_default_graph()
@@ -254,21 +226,25 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
))

main_ops = [graph.get_operation_by_name(x) for x in init_op_names]
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
prediction_signature,
},
main_op=tf.group(main_ops) if len(main_ops) > 0 else None)
})
builder.save()
config_path = os.path.join(model_dir, 'assets/pipeline.config')
assert tf.gfile.Exists(config_path)
dst_path = os.path.join(part_dir, 'assets')
dst_config_path = os.path.join(part_dir, 'assets/pipeline.config')
dst_config_path = os.path.join(dst_path, 'pipeline.config')
tf.gfile.MkDir(dst_path)
tf.gfile.Copy(config_path, dst_config_path)
if part_name == 'user' and FLAGS.user_fg_json_path:
dst_fg_path = os.path.join(dst_path, 'fg.json')
tf.gfile.Copy(FLAGS.user_fg_json_path, dst_fg_path)
if part_name == 'item' and FLAGS.item_fg_json_path:
dst_fg_path = os.path.join(dst_path, 'fg.json')
tf.gfile.Copy(FLAGS.item_fg_json_path, dst_fg_path)


def main(argv):
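With the two new flags, per-tower `fg.json` files are copied into the split models' `assets` directories. A possible invocation of the tool (all paths are illustrative):

```bash
python easy_rec/python/tools/split_model_pai.py \
  --model_dir exported/dssm_saved_model \
  --user_model_dir exported/dssm_user_model \
  --item_model_dir exported/dssm_item_model \
  --user_fg_json_path fg/user_fg.json \
  --item_fg_json_path fg/item_fg.json
```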