diff --git a/.git_bin_path b/.git_bin_path
index 925d16083..8f75b6292 100644
--- a/.git_bin_path
+++ b/.git_bin_path
@@ -1,7 +1,7 @@
 {"leaf_name": "data/test", "leaf_file": ["data/test/batch_criteo_sample.tfrecord", "data/test/criteo_sample.tfrecord", "data/test/dwd_avazu_ctr_deepmodel_10w.csv", "data/test/embed_data.csv", "data/test/lookup_data.csv", "data/test/tag_kv_data.csv", "data/test/test.csv", "data/test/test_sample_weight.txt", "data/test/test_with_quote.csv"]}
 {"leaf_name": "data/test/export", "leaf_file": ["data/test/export/data.csv"]}
 {"leaf_name": "data/test/hpo_test/eval_val", "leaf_file": ["data/test/hpo_test/eval_val/events.out.tfevents.1597889819.j63d04245.sqa.eu95"]}
-{"leaf_name": "data/test/inference", "leaf_file": ["data/test/inference/lookup_data_test80.csv", "data/test/inference/taobao_infer_data.txt"]}
+{"leaf_name": "data/test/inference", "leaf_file": ["data/test/inference/lookup_data_test80.csv", "data/test/inference/taobao_infer_data.txt", "data/test/inference/taobao_infer_rtp_data.txt"]}
 {"leaf_name": "data/test/inference/fg_export_multi", "leaf_file": ["data/test/inference/fg_export_multi/saved_model.pb"]}
 {"leaf_name": "data/test/inference/fg_export_multi/assets", "leaf_file": ["data/test/inference/fg_export_multi/assets/pipeline.config"]}
 {"leaf_name": "data/test/inference/fg_export_multi/variables", "leaf_file": ["data/test/inference/fg_export_multi/variables/variables.data-00000-of-00001", "data/test/inference/fg_export_multi/variables/variables.index"]}
@@ -20,6 +20,9 @@
 {"leaf_name": "data/test/inference/tb_multitower_placeholder_rename_export", "leaf_file": ["data/test/inference/tb_multitower_placeholder_rename_export/saved_model.pb"]}
 {"leaf_name": "data/test/inference/tb_multitower_placeholder_rename_export/assets", "leaf_file": ["data/test/inference/tb_multitower_placeholder_rename_export/assets/pipeline.config"]}
 {"leaf_name": "data/test/inference/tb_multitower_placeholder_rename_export/variables", "leaf_file": ["data/test/inference/tb_multitower_placeholder_rename_export/variables/variables.data-00000-of-00001", "data/test/inference/tb_multitower_placeholder_rename_export/variables/variables.index"]}
+{"leaf_name": "data/test/inference/tb_multitower_rtp_export", "leaf_file": ["data/test/inference/tb_multitower_rtp_export/saved_model.pb"]}
+{"leaf_name": "data/test/inference/tb_multitower_rtp_export/assets", "leaf_file": ["data/test/inference/tb_multitower_rtp_export/assets/pipeline.config"]}
+{"leaf_name": "data/test/inference/tb_multitower_rtp_export/variables", "leaf_file": ["data/test/inference/tb_multitower_rtp_export/variables/variables.data-00000-of-00001", "data/test/inference/tb_multitower_rtp_export/variables/variables.index"]}
 {"leaf_name": "data/test/latest_ckpt_test", "leaf_file": ["data/test/latest_ckpt_test/model.ckpt-500.data-00000-of-00001", "data/test/latest_ckpt_test/model.ckpt-500.index", "data/test/latest_ckpt_test/model.ckpt-500.meta"]}
 {"leaf_name": "data/test/rtp", "leaf_file": ["data/test/rtp/taobao_fg_pred.out", "data/test/rtp/taobao_test_bucketize_feature.txt", "data/test/rtp/taobao_test_feature.txt", "data/test/rtp/taobao_test_input.txt", "data/test/rtp/taobao_train_bucketize_feature.txt", "data/test/rtp/taobao_train_feature.txt", "data/test/rtp/taobao_train_input.txt", "data/test/rtp/taobao_valid.csv", "data/test/rtp/taobao_valid_feature.txt"]}
 {"leaf_name": "data/test/tb_data", "leaf_file": ["data/test/tb_data/taobao_ad_feature_gl", "data/test/tb_data/taobao_clk_edge_gl", "data/test/tb_data/taobao_multi_seq_test_data", "data/test/tb_data/taobao_multi_seq_train_data", "data/test/tb_data/taobao_noclk_edge_gl", "data/test/tb_data/taobao_test_data", "data/test/tb_data/taobao_test_data_for_expr", "data/test/tb_data/taobao_test_data_kd", "data/test/tb_data/taobao_train_data", "data/test/tb_data/taobao_train_data_for_expr", "data/test/tb_data/taobao_train_data_kd", "data/test/tb_data/taobao_user_profile_gl"]}
diff --git a/.git_bin_url b/.git_bin_url
index 1d0cd2136..c8394cd8a 100644
--- a/.git_bin_url
+++ b/.git_bin_url
@@ -1,7 +1,7 @@
 {"leaf_path": "data/test", "sig": "656d73b4e78d0d71e98120050bc51387", "remote_path": "data/git_oss_sample_data/data_test_656d73b4e78d0d71e98120050bc51387"}
 {"leaf_path": "data/test/export", "sig": "c2e5ad1e91edb55b215ea108b9f14537", "remote_path": "data/git_oss_sample_data/data_test_export_c2e5ad1e91edb55b215ea108b9f14537"}
 {"leaf_path": "data/test/hpo_test/eval_val", "sig": "fef5f6cd659c35b713c1b8bcb97c698f", "remote_path": "data/git_oss_sample_data/data_test_hpo_test_eval_val_fef5f6cd659c35b713c1b8bcb97c698f"}
-{"leaf_path": "data/test/inference", "sig": "e2c4b0f07ff8568eb7b8e1819326d296", "remote_path": "data/git_oss_sample_data/data_test_inference_e2c4b0f07ff8568eb7b8e1819326d296"}
+{"leaf_path": "data/test/inference", "sig": "9725274cad0f27baf561ebfaf7946846", "remote_path": "data/git_oss_sample_data/data_test_inference_9725274cad0f27baf561ebfaf7946846"}
 {"leaf_path": "data/test/inference/fg_export_multi", "sig": "c6690cef053aed9e2011bbef90ef33e7", "remote_path": "data/git_oss_sample_data/data_test_inference_fg_export_multi_c6690cef053aed9e2011bbef90ef33e7"}
 {"leaf_path": "data/test/inference/fg_export_multi/assets", "sig": "7fe7a4525f5d46cc763172f5200e96e0", "remote_path": "data/git_oss_sample_data/data_test_inference_fg_export_multi_assets_7fe7a4525f5d46cc763172f5200e96e0"}
 {"leaf_path": "data/test/inference/fg_export_multi/variables", "sig": "1f9aad9744382c6d5b5f152d556d9b30", "remote_path": "data/git_oss_sample_data/data_test_inference_fg_export_multi_variables_1f9aad9744382c6d5b5f152d556d9b30"}
@@ -20,6 +20,9 @@
 {"leaf_path": "data/test/inference/tb_multitower_placeholder_rename_export", "sig": "dc05357e52fd574cba48165bc67af906", "remote_path": "data/git_oss_sample_data/data_test_inference_tb_multitower_placeholder_rename_export_dc05357e52fd574cba48165bc67af906"}
 {"leaf_path": "data/test/inference/tb_multitower_placeholder_rename_export/assets", "sig": "750925c4866bf1db8c3188f604271c72", "remote_path": "data/git_oss_sample_data/data_test_inference_tb_multitower_placeholder_rename_export_assets_750925c4866bf1db8c3188f604271c72"}
 {"leaf_path": "data/test/inference/tb_multitower_placeholder_rename_export/variables", "sig": "56850b4506014ce1bd3ba9b6d60e2770", "remote_path": "data/git_oss_sample_data/data_test_inference_tb_multitower_placeholder_rename_export_variables_56850b4506014ce1bd3ba9b6d60e2770"}
+{"leaf_path": "data/test/inference/tb_multitower_rtp_export", "sig": "f1bc6238cfab648812afca093da5dd6b", "remote_path": "data/git_oss_sample_data/data_test_inference_tb_multitower_rtp_export_f1bc6238cfab648812afca093da5dd6b"}
+{"leaf_path": "data/test/inference/tb_multitower_rtp_export/assets", "sig": "ae1cc9ec956fb900e5df45c4ec255c4b", "remote_path": "data/git_oss_sample_data/data_test_inference_tb_multitower_rtp_export_assets_ae1cc9ec956fb900e5df45c4ec255c4b"}
+{"leaf_path": "data/test/inference/tb_multitower_rtp_export/variables", "sig": "efe52ef308fd6452f3b67fd04cdd22bd", "remote_path": "data/git_oss_sample_data/data_test_inference_tb_multitower_rtp_export_variables_efe52ef308fd6452f3b67fd04cdd22bd"}
"data/git_oss_sample_data/data_test_inference_tb_multitower_rtp_export_variables_efe52ef308fd6452f3b67fd04cdd22bd"} {"leaf_path": "data/test/latest_ckpt_test", "sig": "d41d8cd98f00b204e9800998ecf8427e", "remote_path": "data/git_oss_sample_data/data_test_latest_ckpt_test_d41d8cd98f00b204e9800998ecf8427e"} {"leaf_path": "data/test/rtp", "sig": "76cda60582617ddbb7cd5a49eb68a4b9", "remote_path": "data/git_oss_sample_data/data_test_rtp_76cda60582617ddbb7cd5a49eb68a4b9"} {"leaf_path": "data/test/tb_data", "sig": "c8136915b6e5e9d96b18448cf2e21d3d", "remote_path": "data/git_oss_sample_data/data_test_tb_data_c8136915b6e5e9d96b18448cf2e21d3d"} diff --git a/docs/source/models/mind.md b/docs/source/models/mind.md index 07f0a63bf..84037fa34 100644 --- a/docs/source/models/mind.md +++ b/docs/source/models/mind.md @@ -85,7 +85,7 @@ model_config:{ # use the same numer of capsules for all users const_caps_num: true } - + simi_pow: 20 l2_regularization: 1e-6 time_id_fea: "seq_ts_gap" @@ -101,7 +101,7 @@ model_config:{ - dnn: - hidden_units: dnn每一层的channel数 - use_bn: 是否使用batch_norm, 默认是true -- item_dnn: item侧的dnn参数, 配置同user_dnn +- item_dnn: item侧的dnn参数, 配置同user_dnn - note: item侧不能用batch_norm - pre_capsule_dnn: 进入capsule之前的dnn的配置 - 可选, 配置同user_dnn和item_dnn @@ -117,7 +117,7 @@ model_config:{ - squash_pow: 对squash加的power, 防止squash之后的向量值变得太小 - simi_pow: 对相似度做的倍数, 放大interests之间的差异 - embedding_regularization: 对embedding部分加regularization,防止overfit -- user_seq_combine: +- user_seq_combine: - CONCAT: 多个seq之间采取concat的方式融合 - SUM: 多个seq之间采取sum的方式融合, default是SUM - time_id_fea: time_id feature的name, 对应feature_config里面定义的特征 @@ -128,6 +128,7 @@ model_config:{ - 行为序列特征可以加上time_id, time_id经过1 dimension的embedding后, 在time维度进行softmax, 然后和其它sequence feature的embedding相乘 - time_id取值的方式可参考: + - 训练数据: Math.round((2 * Math.log1p((labelTime - itemTime) / 60.) / Math.log(2.))) + 1 - inference: Math.round((2 * Math.log1p((currentTime - itemTime) / 60.) / Math.log(2.))) + 1 - 此处的时间(labelTime, itemTime, currentTime) 为seconds @@ -136,17 +137,19 @@ model_config:{ - 使用增量训练,增量训练可以防止负采样的穿越。 -- 使用HPO对squash_pow[0.1 - 1.0]和simi_pow[10 - 100]进行搜索调优。 +- 使用HPO对squash_pow\[0.1 - 1.0\]和simi_pow\[10 - 100\]进行搜索调优。 - 要看的指标是召回率,准确率和兴趣损失,三个指标要一起看。 - 使用全网的点击数据来生成训练样本,全网的行为会更加丰富,这有利于mind模型的训练。 - 数据清洗: + - 把那些行为太少的item直接在构造行为序列的时候就挖掉 - 排除爬虫或者作弊用户 - 数据采样: + - mind模型的训练默认是以点击为目标 - 如果业务指标是到交易,那么可以对交易的样本重采样 @@ -155,9 +158,11 @@ model_config:{ [MIND_demo.config](https://easyrec.oss-cn-beijing.aliyuncs.com/config/mind_on_taobao_neg_sam.config) ### 效果评估 + 离线的效果评估主要看在测试集上的hitrate. 
diff --git a/docs/source/pre_check.md b/docs/source/pre_check.md
index efc476594..1a1d5216a 100644
--- a/docs/source/pre_check.md
+++ b/docs/source/pre_check.md
@@ -3,12 +3,12 @@
 The pre-check feature was developed because training often fails due to dirty data or configuration errors.
 Turning on check mode at training time, or running the pre_check script before training, checks the data_config and part of the train_config, scans all of the data, reports the relevant information when it hits an anomaly, and suggests a fix.
- 
 ### Commands
 
 #### Local
 
 Option 1: run the pre_check script:
+
 ```bash
 PYTHONPATH=. python easy_rec/python/tools/pre_check.py --pipeline_config_path samples/model_config/din_on_taobao.config --data_input_path data/test/check_data/csv_data_for_check
 ```
@@ -16,17 +16,19 @@ PYTHONPATH=. python easy_rec/python/tools/pre_check.py --pipeline_config_path sa
 Option 2: turn on check mode at training time (disabled by default):
 This slows down training, so enabling check mode is not recommended for routine online training.
+
 ```bash
 python -m easy_rec.python.train_eval --pipeline_config_path samples/model_config/din_on_taobao.config --check_mode
 ```
+
 - pipeline_config_path: path of the config file
 - data_input_path: path of the data to check; if not specified, the train_input_path and eval_input_path from pipeline_config_path are used
 - check_mode: defaults to False
- 
 #### On PAI
 
 Option 1: run the pre_check script:
+
 ```sql
 pai -name easy_rec_ext -project algo_public
 -Dcmd='check'
@@ -42,6 +44,7 @@ pai -name easy_rec_ext -project algo_public
 Option 2: turn on check mode at training time (disabled by default):
 This slows down training, so enabling check mode is not recommended for routine online training.
+
 ```sql
 pai -name easy_rec_ext -project algo_public
 -Dcmd='train'
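A small local wrapper can tie the two pre_check.md options above together: run the check first and start training only if it passes. This is a minimal sketch, not an official tool; it reuses the sample paths from the commands above and assumes pre_check.py exits with a non-zero status when it finds problems (the doc does not spell out the exit behaviour).

```python
import os
import subprocess
import sys

CONFIG = "samples/model_config/din_on_taobao.config"
CHECK_DATA = "data/test/check_data/csv_data_for_check"

# Run from the repo root with PYTHONPATH=., as in the pre_check.md commands.
env = dict(os.environ, PYTHONPATH=".")

check = subprocess.run(
    [sys.executable, "easy_rec/python/tools/pre_check.py",
     "--pipeline_config_path", CONFIG,
     "--data_input_path", CHECK_DATA],
    env=env,
)
if check.returncode != 0:
    sys.exit("pre-check failed; fix the data/config before training")

# Train without --check_mode so routine training is not slowed down.
subprocess.run(
    [sys.executable, "-m", "easy_rec.python.train_eval",
     "--pipeline_config_path", CONFIG],
    env=env,
    check=True,
)
```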
-160,7 +160,7 @@ } - + @@ -172,1076 +172,1076 @@

Table of Contents

- - + +

easy_rec/python/protos/autoint.proto

Top

 
-      
+
         

AutoInt


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
multi_head_num uint32 required
The number of heads Default: 1
multi_head_size uint32 required
The dimension of heads 
interacting_layer_num uint32 required
The number of interacting layers Default: 1
l2_regularization float required
 Default: 0.0001
- - - - - - - - + + + + + + + +

easy_rec/python/protos/collaborative_metric_learning.proto

Top

 
-      
+
         

CoMetricLearningI2I


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
session_id string optional
 
highway HighWayTower repeated
 
input string optional
 
dnn DNN required
 
l2_regularization float required
 Default: 0.0001
output_l2_normalized_emb bool required
 Default: true
sample_id string optional
 
circle_loss CircleLoss optional
 
multi_similarity_loss MultiSimilarityLoss optional
 
item_id string optional
 
- - - - - - - - + + + + + + + +

easy_rec/python/protos/dataset.proto

Top

 
-      
+
         

DatasetConfig


 
-        
+
           
-              
+
                 
-              
+
                 
@@ -1249,7 +1249,7 @@ 

DatasetConfig

- + @@ -1258,14 +1258,14 @@

DatasetConfig

For multiple target models such as MMOE multiple label_fields will be set. - + - + @@ -1273,14 +1273,14 @@

DatasetConfig

- + - + @@ -1289,7 +1289,7 @@

DatasetConfig

it is suggested to do full data shuffle before training especially when the performance of models is not good. Default: 32 - + @@ -1297,28 +1297,28 @@

DatasetConfig

- + - + - + - + @@ -1332,7 +1332,7 @@

DatasetConfig

for RTPInput and OdpsRTPInput it is usually set to '\002' Default: , - + @@ -1341,7 +1341,7 @@

DatasetConfig

or too large numbers(suggested be to small than number of the cores) Default: 8 - + @@ -1351,7 +1351,7 @@

DatasetConfig

such as '1,2,4', where 1,2 are label columns, and 4 is the feature column, column 0,3 are not used, - + @@ -1359,7 +1359,7 @@

DatasetConfig

- + @@ -1367,14 +1367,14 @@

DatasetConfig

- + - + @@ -1382,7 +1382,7 @@

DatasetConfig

- + @@ -1390,14 +1390,14 @@

DatasetConfig

- + - + @@ -1405,28 +1405,28 @@

DatasetConfig

- + - + - + - + @@ -1436,118 +1436,118 @@

DatasetConfig

and the number and the order of input_fields may not be the same as that in csv files. Default: false - + - + - + - + - + - + - +
FieldTypeLabelDescription
batch_size uint32 optional
mini batch size to use for training and evaluation. Default: 32
auto_expand_input_fields bool
set auto_expand_input_fields to true to
 auto_expand field[1-21] to field1, field2, ..., field21 Default: false
label_fields string
label_sep string repeated
label separator 
label_dim uint32
label dimensions which need to be set when there
 are labels have dimension > 1 
shuffle bool optional
whether to shuffle data Default: true
shuffle_buffer_size int32
num_epochs uint32
The number of times a data source is read. If set to zero, the data source
 will be reused indefinitely. Default: 0
prefetch_size uint32 optional
Number of decoded batches to prefetch. Default: 32
shard bool optional
shard dataset to 1/num_workers in distribute mode Default: false
input_type DatasetConfig.InputType required
 
separator string
num_parallel_calls uint32
selected_cols string
selected_col_types string
selected col types, only used for OdpsInput/OdpsInputV2
 to avoid error setting of data types 
input_fields DatasetConfig.Field
the input fields must be the same number and in the
 same order as data in csv files or odps tables 
rtp_separator string optional
for RTPInput only Default: ;
ignore_error bool
ignore some data errors
 it is not suggested to set this parameter Default: false
pai_worker_queue bool
whether to use pai global shuffle queue, only for OdpsInput,
 OdpsInputV2, OdpsRTPInputV2 Default: false
pai_worker_slice_num int32 optional
 Default: 100
chief_redundant bool
if true, one worker will duplicate the data of the chief node
 and undertake the gradient computation of the chief node Default: false
sample_weight string optional
input field for sample weight 
data_compression_type string optional
the compression type of tfrecord 
n_data_batch_tfrecord uint32 optional
n data for one feature in tfrecord 
with_header bool
negative_sampler NegativeSampler optional
 
negative_sampler_v2 NegativeSamplerV2 optional
 
hard_negative_sampler HardNegativeSampler optional
 
hard_negative_sampler_v2 HardNegativeSamplerV2 optional
 
negative_sampler_in_memory NegativeSamplerInMemory optional
 
eval_batch_size uint32 optional
 Default: 4096
- - - + + +

DatasetConfig.Field


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
input_name string required
 
input_type DatasetConfig.FieldType required
 Default: STRING
default_val string optional
 
input_dim uint32 optional
 Default: 1
input_shape uint32 optional
 Default: 1
- - - + + +

HardNegativeSampler

Weighted Random Sampling ItemID not in Batch and Sampling Hard Edge
- + - + @@ -1555,7 +1555,7 @@

HardNegativeSampler

- + @@ -1563,7 +1563,7 @@

HardNegativeSampler

- + @@ -1571,73 +1571,73 @@

HardNegativeSampler

- + - + - + - + - + - + - + - +
FieldTypeLabelDescription
user_input_path string
user data path
 userid weight 
item_input_path string
item data path
 itemid weight attrs 
hard_neg_edge_input_path string
hard negative edge path
 userid itemid weight 
num_sample uint32 required
number of negative sample 
num_hard_sample uint32 required
max number of hard negative sample 
attr_fields string repeated
field names of attrs in train data or eval data 
item_id_field string required
field name of item_id in train data or eval data 
user_id_field string required
field name of user_id in train data or eval data 
attr_delimiter string optional
 Default: :
num_eval_sample uint32 optional
 Default: 0
- - - + + +

HardNegativeSamplerV2

Weighted Random Sampling ItemID not with Edge and Sampling Hard Edge
- + - + @@ -1645,7 +1645,7 @@

HardNegativeSamplerV2

- + @@ -1653,7 +1653,7 @@

HardNegativeSamplerV2

- + @@ -1661,7 +1661,7 @@

HardNegativeSamplerV2

- + @@ -1669,73 +1669,73 @@

HardNegativeSamplerV2

- + - + - + - + - + - + - + - +
FieldTypeLabelDescription
user_input_path string
user data path
 userid weight 
item_input_path string
item data path
 itemid weight attrs 
pos_edge_input_path string
positive edge path
 userid itemid weight 
hard_neg_edge_input_path string
hard negative edge path
 userid itemid weight 
num_sample uint32 required
number of negative sample 
num_hard_sample uint32 required
max number of hard negative sample 
attr_fields string repeated
field names of attrs in train data or eval data 
item_id_field string required
field name of item_id in train data or eval data 
user_id_field string required
field name of user_id in train data or eval data 
attr_delimiter string optional
 Default: :
num_eval_sample uint32 optional
 Default: 0
- - - + + +

NegativeSampler

Weighted Random Sampling ItemID not in Batch
- + - + @@ -1743,59 +1743,59 @@

NegativeSampler

- + - + - + - + - + - +
FieldTypeLabelDescription
input_path string
sample data path
 itemid weight attrs 
num_sample uint32 required
number of negative sample 
attr_fields string repeated
field names of attrs in train data or eval data 
item_id_field string required
field name of item_id in train data or eval data 
attr_delimiter string optional
 Default: :
num_eval_sample uint32 optional
 Default: 0
- - - + + +

NegativeSamplerInMemory


 
-        
+
           
-              
+
                 
@@ -1803,59 +1803,59 @@ 

NegativeSamplerInMemory

- + - + - + - + - + - +
FieldTypeLabelDescription
input_path string
sample data path
 itemid weight attrs 
num_sample uint32 required
number of negative sample 
attr_fields string repeated
field names of attrs in train data or eval data 
item_id_field string required
field name of item_id in train data or eval data 
attr_delimiter string optional
 Default: :
num_eval_sample uint32 optional
 Default: 0
- - - + + +

NegativeSamplerV2

Weighted Random Sampling ItemID not with Edge
- + - + @@ -1863,7 +1863,7 @@

NegativeSamplerV2

- + @@ -1871,7 +1871,7 @@

NegativeSamplerV2

- + @@ -1879,58 +1879,58 @@

NegativeSamplerV2

- + - + - + - + - + - + - +
FieldTypeLabelDescription
user_input_path string
user data path
 userid weight 
item_input_path string
item data path
 itemid weight attrs 
pos_edge_input_path string
positive edge path
 userid itemid weight 
num_sample uint32 required
number of negative sample 
attr_fields string repeated
field names of attrs in train data or eval data 
item_id_field string required
field name of item_id in train data or eval data 
user_id_field string required
field name of user_id in train data or eval data 
attr_delimiter string optional
 Default: :
num_eval_sample uint32 optional
 Default: 0
- - - - + + + +

DatasetConfig.FieldType


         
@@ -1938,46 +1938,46 @@ 

DatasetConfig.FieldType

- + - + - + - + - + - + - +
NameNumberDescription
INT32 0
INT64 1
STRING 2
FLOAT 4
DOUBLE 5
BOOL 6
- +

DatasetConfig.InputType


         
@@ -1985,878 +1985,878 @@ 

DatasetConfig.InputType

- + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - +
NameNumberDescription
CSVInput 10
csv format input, could be used in local or hdfs
CSVInputV2 11
@Depreciated
CSVInputEx 12
extended csv format, allow quote in fields
OdpsInput 2
@Depreciated, has memory leak problem
OdpsInputV2 3
odps input, used on pai
DataHubInput 15
OdpsInputV3 9
RTPInput 4
RTPInputV2 5
OdpsRTPInput 601
OdpsRTPInputV2 602
TFRecordInput 7
BatchTFRecordInput 14
DummyInput 8
for the purpose to debug performance bottleneck of
 input pipelines
KafkaInput 13
HiveInput 16
CriteoInput 1001
- - - - - + + + + +

easy_rec/python/protos/data_source.proto

Top

 
-      
+
         

BinaryDataInput


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
category_path string repeated
support gfile.Glob 
dense_path string repeated
 
label_path string repeated
 
- - - + + +

DatahubServer


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
akId string required
 
akSecret string required
 
region string required
 
project string required
 
topic string required
 
shard_num uint32 required
 
life_cycle uint32 required
 
- - - + + +

KafkaServer


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
server string required
 
topic string required
 
group string required
 
partitions uint32 required
 
offset uint32 repeated
 
- - - - - - - - + + + + + + + +

easy_rec/python/protos/dbmtl.proto

Top

 
-      
+
         

DBMTL


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
bottom_dnn DNN optional
shared bottom dnn layer 
expert_dnn DNN optional
mmoe expert dnn layer definition 
num_expert uint32 optional
number of mmoe experts Default: 0
task_towers BayesTaskTower repeated
bayes task tower 
l2_regularization float optional
l2 regularization Default: 0.0001
- - - - - - - - + + + + + + + +

easy_rec/python/protos/dcn.proto

Top

 
-      
+
         

CrossTower


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
input string required
 
cross_num uint32 required
The number of cross layers Default: 3
- - - + + +

DCN


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
deep_tower Tower required
 
cross_tower CrossTower required
 
final_dnn DNN required
 
l2_regularization float required
 Default: 0.0001
- - - - - - - - + + + + + + + +

easy_rec/python/protos/deepfm.proto

Top

 
-      
+
         

DeepFM


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
dnn DNN required
 
final_dnn DNN optional
 
wide_output_dim uint32 optional
 Default: 1
wide_regularization float optional
deprecated Default: 0.0001
dense_regularization float optional
deprecated Default: 0.0001
l2_regularization float optional
 Default: 0.0001
- - - - - - - - + + + + + + + +

easy_rec/python/protos/dlrm.proto

Top

 
-      
+
         

DLRM


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
top_dnn DNN required
 
bot_dnn DNN required
 
arch_interaction_op string optional
options are: dot and cat Default: dot
arch_interaction_itself bool optional
whether a feature will interact with itself Default: false
arch_with_dense_feature bool optional
whether to include dense features after interaction Default: false
l2_regularization float optional
 Default: 1e-05
- - - - - - - - + + + + + + + +

easy_rec/python/protos/dnn.proto

Top

 
-      
+
         

DNN


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
hidden_units uint32 repeated
hidden units for each layer 
dropout_ratio float repeated
ratio of dropout 
activation string optional
activation function Default: tf.nn.relu
use_bn bool optional
use batch normalization Default: true
- - - - - - - - + + + + + + + +

easy_rec/python/protos/dropoutnet.proto

Top

 
-      
+
         

DropoutNet


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
user_content DNN required
 
user_preference DNN required
 
item_content DNN required
 
item_preference DNN required
 
user_tower DNN required
 
item_tower DNN required
 
l2_regularization float required
 Default: 0
user_dropout_rate float required
 Default: 0
item_dropout_rate float required
 Default: 0.5
softmax_loss SoftmaxCrossEntropyWithNegativeMining optional
 
- - - - - - - - + + + + + + + +

easy_rec/python/protos/dssm.proto

Top

 
-      
+
         

DSSM


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
user_tower DSSMTower required
 
item_tower DSSMTower required
 
l2_regularization float required
 Default: 0.0001
simi_func Similarity optional
 Default: COSINE
scale_simi bool optional
add a layer for scaling the similarity Default: true
item_id string optional
 
ignore_in_batch_neg_sam bool required
 Default: false
- - - + + +

DSSMTower


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
id string required
 
dnn DNN required
 
- - - - - - - - + + + + + + + +

easy_rec/python/protos/eas_serving.proto

Top

 
-      
+
         

Config


 
-        
+
           
-              
+
                 
@@ -2864,14 +2864,14 @@ 

Config

- + - + @@ -2879,503 +2879,503 @@

Config

- + - + - + - + - +
FieldTypeLabelDescription
column_delim string
例如输入特征为"1005,109;0;93eaba74",此时分号分割的为column,
 逗号分割的为每个column的多个feature, 下划线分割为feature名字和对应的value。 
feature_delim string
 
hash string
指定字符串hash分桶的算法,支持HarmHash(对应于tf.strings.to_hash_bucket_fast())
 和SipHash(对应于tf.strings.to_hash_bucket_strong())两种字符串hash分桶算法 
embeddings Config.EmbeddingsEntry repeated
embedding_name to embedding 
embedding_max_norm Config.EmbeddingMaxNormEntry repeated
指定embedding lookup的结果的最大L2-norm 
embedding_combiner Config.EmbeddingCombinerEntry repeated
指定embedding的combiner策略,支持sum, mean和sqrtn 
model Model
 
- - - + + +

Config.EmbeddingCombinerEntry


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
key string
 
value string
 
- - - + + +

Config.EmbeddingMaxNormEntry


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
key string
 
value float
 
- - - + + +

Config.EmbeddingsEntry


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
key string
 
value Embedding
 
- - - + + +

Embedding


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
partition_num int32
指定该embedding切分的总数 
parts EmbeddingPart repeated
 
- - - + + +

EmbeddingPart


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
embedding_part_path string
指定EmbeddingPartData(*.pb)所在的路径 
partition_id int32
指定该embedding part所属第几个part 
shape int64 repeated
指定该embedding part的shape(可以从EmbeddingPartData中读取) 
deploy_strategy string
embedding part的部署策略, 支持本地部署(local)和远程部署(remote) 
- - - + + +

EmbeddingPartData


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
shape int64 repeated
Shape of the embedding 
data float repeated
Data 
- - - + + +

Model


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
model_path string
指定模型所在路径,便于加载模型 
model_signature_name string
指定模型的sinature的名字 
model_inputs ModelInput repeated
model input description 
- - - + + +

ModelInput


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
feature_name string
 
embedding_name string
 
placeholder_name string
 
weight_name string
 
- - - - - - - - + + + + + + + +

easy_rec/python/protos/easy_rec_model.proto

Top

 
-      
+
         

DummyModel

for input performance test
- - - + + +

EasyRecModel


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
@@ -3384,35 +3384,35 @@ 

EasyRecModel

add regularization to all variables with "embedding_weights:" in name Default: 0 - + - + - + - + - + @@ -3420,274 +3420,274 @@

EasyRecModel

- + - + - + - +
FieldTypeLabelDescription
model_class string required
 
feature_groups FeatureGroupConfig repeated
actually input layers, each layer produce a group of feature 
dummy DummyModel optional
 
wide_and_deep WideAndDeep optional
 
deepfm DeepFM optional
 
multi_tower MultiTower optional
 
fm FM optional
 
dcn DCN optional
 
autoint AutoInt optional
 
dlrm DLRM optional
 
dssm DSSM optional
 
mind MIND optional
 
dropoutnet DropoutNet optional
 
metric_learning CoMetricLearningI2I optional
 
mmoe MMoE optional
 
esmm ESMM optional
 
dbmtl DBMTL optional
 
simple_multi_task SimpleMultiTask optional
 
ple PLE optional
 
rocket_launching RocketLaunching optional
 
seq_att_groups SeqAttGroupConfig repeated
 
embedding_regularization float
loss_type LossType optional
 Default: CLASSIFICATION
num_class uint32 optional
 Default: 1
use_embedding_variable bool optional
 Default: false
kd KD repeated
 
restore_filters string
filter variables matching any pattern in restore_filters
 common filters are Adam, Momentum, etc. 
variational_dropout VariationalDropoutLayer optional
 
losses Loss repeated
 
f1_reweight_loss F1ReweighedLoss optional
 
- - - + + +

KD

for knowledge distillation
- + - + - + - + - + - + - + - + - + - +
FieldTypeLabelDescription
loss_name string optional
 
pred_name string required
 
pred_is_logits bool optional
default to be logits Default: true
soft_label_name string required
for CROSS_ENTROPY_LOSS, soft_label must be logits instead of probs 
label_is_logits bool optional
default to be logits Default: true
loss_type LossType required
currently only support CROSS_ENTROPY_LOSS and L2_LOSS 
loss_weight float optional
 Default: 1
temperature float optional
only for loss_type == CROSS_ENTROPY_LOSS Default: 1
- - - - - - - - + + + + + + + +

easy_rec/python/protos/esmm.proto

Top

 
-      
+
         

ESMM


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
groups Tower repeated
 
ctr_tower TaskTower required
 
cvr_tower TaskTower required
 
l2_regularization float required
 Default: 0.0001
- - - - - - - - + + + + + + + +

easy_rec/python/protos/eval.proto

Top

 
-      
+
         

AUC


 
-        
+
           
-              
+
                 
-              
+
             
FieldTypeLabelDescription
num_thresholds uint32 optional
 Default: 200
- - - + + +

Accuracy


 
-        
 
-        
-      
+
+
+
         

AvgPrecisionAtTopK


 
-        
+
           
-              
+
                 
-              
+
             
FieldTypeLabelDescription
topk uint32 optional
 Default: 5
- - - + + +

EvalConfig

Message for configuring EasyRecModel evaluation jobs (eval.py).
- + - + - + - + - + - + @@ -3695,139 +3695,139 @@

EvalConfig

- + - +
FieldTypeLabelDescription
num_examples uint32 optional
Number of examples to process of evaluation. Default: 0
eval_interval_secs uint32 optional
How often to run evaluation. Default: 300
max_evals uint32 optional
Maximum number of times to run evaluation. If set to 0, will run forever. Default: 0
save_graph bool optional
Whether the TensorFlow graph used for evaluation should be saved to disk. Default: false
metrics_set EvalMetrics
Type of metrics to use for evaluation.
 possible values: 
eval_online bool optional
Evaluation online with batch forward data of training Default: false
- - - + + +

EvalMetrics


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
auc AUC optional
 
recall_at_topk RecallAtTopK optional
 
mean_absolute_error MeanAbsoluteError optional
 
mean_squared_error MeanSquaredError optional
 
accuracy Accuracy optional
 
max_f1 Max_F1 optional
 
root_mean_squared_error RootMeanSquaredError optional
 
gauc GAUC optional
 
session_auc SessionAUC optional
 
recall Recall optional
 
precision Precision optional
 
precision_at_topk AvgPrecisionAtTopK optional
 
- - - + + +

GAUC


 
-        
+
           
-              
+
                 
-              
+
                 
@@ -3837,97 +3837,97 @@ 

GAUC

* "mean_by_sample_num": weighted mean with sample num of different users * "mean_by_positive_num": weighted mean with positive sample num of different users Default: mean - +
FieldTypeLabelDescription
uid_field string required
uid field name 
reduction string
- - - + + +

Max_F1


 
-        
 
-        
-      
+
+
+
         

MeanAbsoluteError


 
-        
 
-        
-      
+
+
+
         

MeanSquaredError


 
-        
 
-        
-      
+
+
+
         

Precision


 
-        
 
-        
-      
+
+
+
         

Recall


 
-        
 
-        
-      
+
+
+
         

RecallAtTopK


 
-        
+
           
-              
+
                 
-              
+
             
FieldTypeLabelDescription
topk uint32 optional
 Default: 5
- - - + + +

RootMeanSquaredError


 
-        
 
-        
-      
+
+
+
         

SessionAUC


 
-        
+
           
-              
+
                 
-              
+
                 
@@ -3937,38 +3937,38 @@ 

SessionAUC

* "mean_by_sample_num": weighted mean with sample num of different sessions * "mean_by_positive_num": weighted mean with positive sample num of different sessions Default: mean - +
FieldTypeLabelDescription
session_id_field string required
session id field name 
reduction string
- - - - - - - - + + + + + + + +

easy_rec/python/protos/export.proto

Top

 
-      
+
         

ExportConfig

Message for configuring exporting models.
- + - + @@ -3977,7 +3977,7 @@

ExportConfig

which is only supported by classification model right now, while other models support static batch_size Default: -1 - + @@ -3988,28 +3988,28 @@

ExportConfig

latest: export the best model according to best_exporter_metric none: do not perform export Default: final - + - + - + - + @@ -4018,258 +4018,258 @@

ExportConfig

early_stop_func(eval_results, early_stop_params) return True if should stop - + - + - + - + - + - + - + - + - + - +
FieldTypeLabelDescription
batch_size int32
exporter_type string
best_exporter_metric string optional
the metric used to determine the best checkpoint Default: auc
metric_bigger bool optional
metric value the bigger the best Default: true
enable_early_stop bool optional
enable early stop Default: false
early_stop_func string
early_stop_params string optional
custom early stop parameters 
max_check_steps int32 optional
early stop max check steps Default: 10000
multi_placeholder bool optional
each feature has a placeholder Default: true
exports_to_keep int32 optional
export to keep, only for exporter_type in [best, latest] Default: 1
multi_value_fields MultiValueFields optional
multi value field list 
placeholder_named_by_input bool optional
is placeholder named by input Default: false
filter_inputs bool optional
filter out inputs, only keep effective ones Default: true
export_features bool optional
export the original feature values as string Default: false
export_rtp_outputs bool optional
export the outputs required by RTP Default: false
- - - + + +

MultiValueFields


 
-        
+
           
-              
+
                 
-              
+
             
FieldTypeLabelDescription
input_name string repeated
 
- - - - - - - - + + + + + + + +

easy_rec/python/protos/feature_config.proto

Top

 
-      
+
         

AttentionCombiner


 
-        
 
-        
-      
+
+
+
         

FeatureConfig


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
@@ -4278,293 +4278,293 @@ 

FeatureConfig

scientific format is not used. in default it is not allowed to convert float/double to string Default: -1 - + - + - + - + - + - + - + - +
FieldTypeLabelDescription
feature_name string optional
 
input_names string repeated
input field names: must be included in DatasetConfig.input_fields 
feature_type FeatureConfig.FeatureType required
 Default: IdFeature
embedding_name string optional
 
embedding_dim uint32 optional
 Default: 0
hash_bucket_size uint64 optional
 Default: 0
num_buckets uint64 optional
for categorical_column_with_identity Default: 0
boundaries double repeated
only for raw features 
separator string optional
separator with in features Default: |
kv_separator string optional
delimeter to separator key from value 
seq_multi_sep string optional
delimeter to separate sequence multi-values 
vocab_file string optional
 
vocab_list string repeated
 
shared_names string repeated
many other field share this config 
lookup_max_sel_elem_num int32 optional
lookup max select element number, default 10 Default: 10
max_partitions int32 optional
max_partitions Default: 1
combiner string optional
combiner Default: mean
initializer Initializer optional
embedding initializer 
precision int32
min_val double optional
normalize raw feature to [0-1] Default: 0
max_val double optional
 Default: 0
raw_input_dim uint32 optional
raw feature of multiple dimensions Default: 1
sequence_combiner SequenceCombiner optional
sequence feature combiner 
sub_feature_type FeatureConfig.FeatureType optional
sub feature type for sequence feature Default: IdFeature
sequence_length uint32 optional
sequence length Default: 1
expression string optional
for expr feature 
- - - + + +

FeatureConfigV2


 
-        
+
           
-              
+
                 
-              
+
             
FieldTypeLabelDescription
features FeatureConfig repeated
 
- - - + + +

FeatureGroupConfig


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
group_name string optional
 
feature_names string repeated
 
wide_deep WideOrDeep optional
 Default: DEEP
sequence_features SeqAttGroupConfig repeated
 
- - - + + +

MultiHeadAttentionCombiner


 
-        
 
-        
-      
+
+
+
         

SeqAttGroupConfig


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
group_name string optional
 
seq_att_map SeqAttMap repeated
 
tf_summary bool optional
 Default: false
seq_dnn DNN optional
 
allow_key_search bool optional
 Default: false
- - - + + +

SeqAttMap


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
key string repeated
 
hist_seq string repeated
 
- - - + + +

SequenceCombiner


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
attention AttentionCombiner optional
 
multi_head_attention MultiHeadAttentionCombiner optional
 
text_cnn TextCnnCombiner optional
 
- - - + + +

TextCnnCombiner


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
filter_sizes uint32 repeated
 
num_filters uint32 repeated
 
- - - - + + + +

FeatureConfig.FeatureType


         
@@ -4572,52 +4572,52 @@ 

FeatureConfig.FeatureType

- + - + - + - + - + - + - + - +
NameNumberDescription
IdFeature 0
RawFeature 1
TagFeature 2
ComboFeature 3
LookupFeature 4
SequenceFeature 5
ExprFeature 6
- +

FeatureConfig.FieldType


         
@@ -4625,46 +4625,46 @@ 

FeatureConfig.FieldType

- + - + - + - + - + - + - +
NameNumberDescription
INT32 0
INT64 1
STRING 2
FLOAT 4
DOUBLE 5
BOOL 6
- +

WideOrDeep


         
@@ -4672,664 +4672,664 @@ 

WideOrDeep

- + - + - + - +
NameNumberDescription
DEEP 0
WIDE 1
WIDE_AND_DEEP 2
- - - - - + + + + +

easy_rec/python/protos/fm.proto

Top

 
-      
+
         

FM


 
-        
+
           
-              
+
                 
-              
+
             
FieldTypeLabelDescription
l2_regularization float optional
 Default: 0.0001
- - - - - - - - + + + + + + + +

easy_rec/python/protos/hive_config.proto

Top

 
-      
+
         

HiveConfig


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
host string required
hive master's ip 
port uint32 required
hive port Default: 10000
username string required
hive username 
database string required
hive database Default: default
table_name string required
 
hash_fields string required
 
limit_num uint32 optional
 Default: 0
fetch_size uint32 required
 Default: 512
- - - - - - - - + + + + + + + +

easy_rec/python/protos/hyperparams.proto

Top

 
-      
+
         

ConstantInitializer


 
-        
+
           
-              
+
                 
-              
+
             
FieldTypeLabelDescription
consts float repeated
 
- - - + + +

GlorotNormalInitializer


 
-        
 
-        
-      
+
+
+
         

Initializer

Proto with one-of field for initializers.
- + - + - + - + - + - +
FieldTypeLabelDescription
truncated_normal_initializer TruncatedNormalInitializer optional
 
random_normal_initializer RandomNormalInitializer optional
 
glorot_normal_initializer GlorotNormalInitializer optional
 
constant_initializer ConstantInitializer optional
 
- - - + + +

L1L2Regularizer

Configuration proto for L2 Regularizer.
- + - + - + - +
FieldTypeLabelDescription
scale_l1 float optional
 Default: 1
scale_l2 float optional
 Default: 1
- - - + + +

L1Regularizer

Configuration proto for L1 Regularizer.
- + - + - +
FieldTypeLabelDescription
scale float optional
 Default: 1
- - - + + +

L2Regularizer

Configuration proto for L2 Regularizer.
- + - + - +
FieldTypeLabelDescription
scale float optional
 Default: 1
- - - + + +

RandomNormalInitializer

Configuration proto for random normal initializer. See
https://www.tensorflow.org/api_docs/python/tf/random_normal_initializer
- + - + - + - +
FieldTypeLabelDescription
mean float optional
 Default: 0
stddev float optional
 Default: 1
- - - + + +

Regularizer

Proto with one-of field for regularizers.
- + - + - + - + - +
FieldTypeLabelDescription
l1_regularizer L1Regularizer optional
 
l2_regularizer L2Regularizer optional
 
l1_l2_regularizer L1L2Regularizer optional
 
- - - + + +

TruncatedNormalInitializer

Configuration proto for truncated normal initializer. See
https://www.tensorflow.org/api_docs/python/tf/truncated_normal_initializer
- + - + - + - +
FieldTypeLabelDescription
mean float optional
 Default: 0
stddev float optional
 Default: 1
- - - - - - - - + + + + + + + +

easy_rec/python/protos/layer.proto

Top

 
-      
+
         

HighWayTower


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
input string required
 
emb_size uint32 required
 
- - - - - - - - + + + + + + + +

easy_rec/python/protos/loss.proto

Top

 
-      
+
         

CircleLoss


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
margin float required
 Default: 0.25
gamma float required
 Default: 32
- - - + + +

F1ReweighedLoss


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
f1_beta_square float required
 Default: 1
label_smoothing float required
 Default: 0
- - - + + +

Loss


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
loss_type LossType required
 
weight float required
 Default: 1
- - - + + +

MultiSimilarityLoss


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
alpha float required
 Default: 2
beta float required
 Default: 50
lamb float required
 Default: 1
eps float required
 Default: 0.1
- - - + + +

SoftmaxCrossEntropyWithNegativeMining


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
num_negative_samples uint32 required
 
margin float required
 Default: 0
gamma float required
 Default: 1
coefficient_of_support_vector float required
 Default: 1
- - - - + + + +

LossType


         
@@ -5337,148 +5337,148 @@ 

LossType

- + - + - + - + - + - + - + - + - + - + - +
NameNumberDescription
CLASSIFICATION 0
L2_LOSS 1
SIGMOID_L2_LOSS 2
CROSS_ENTROPY_LOSS 3
crossentropy loss/log loss
SOFTMAX_CROSS_ENTROPY 4
CIRCLE_LOSS 5
MULTI_SIMILARITY_LOSS 6
SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING 7
PAIR_WISE_LOSS 8
F1_REWEIGHTED_LOSS 9
- - - - - + + + + +

easy_rec/python/protos/mind.proto

Top

 
-      
+
         

Capsule


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
@@ -5486,45 +5486,45 @@ 

Capsule

- +
FieldTypeLabelDescription
max_k uint32 optional
max number of high capsules Default: 5
max_seq_len uint32 required
max behaviour sequence length 
high_dim uint32 required
high capsule embedding vector dimension 
num_iters uint32 optional
number EM iterations Default: 3
routing_logits_scale float optional
routing logits scale Default: 20
routing_logits_stddev float optional
routing logits initial stddev Default: 1
squash_pow float optional
squash power Default: 1
scale_ratio float optional
output ratio Default: 1
const_caps_num bool
constant interest number
 in default, use log(seq_len) Default: false
- - - + + +

MIND


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
@@ -5532,21 +5532,21 @@ 

MIND

- + - + - + @@ -5554,65 +5554,65 @@

MIND

- + - + - + - + - + - + - + - +
FieldTypeLabelDescription
pre_capsule_dnn DNN optional
preprocessing dnn before entering capsule layer 
user_dnn DNN required
dnn layers applied on user_context(none sequence features) 
concat_dnn DNN required
concat user and capsule dnn 
user_seq_combine MIND.UserSeqCombineMethod
method to combine several user sequences
 such as item_ids, category_ids Default: SUM
item_dnn DNN required
dnn layers applied on item features 
capsule_config Capsule required
 
simi_pow float
similarity power, the paper says that the big
 the better Default: 10
simi_func Similarity optional
 Default: COSINE
scale_simi bool optional
add a layer for scaling the similarity Default: true
l2_regularization float required
 Default: 0.0001
time_id_fea string optional
 
item_id string optional
 
ignore_in_batch_neg_sam bool optional
 Default: false
max_interests_simi float optional
 Default: 1
- - - - + + + +

MIND.UserSeqCombineMethod


         
@@ -5620,740 +5620,740 @@ 

MIND.UserSeqCombineMethod

- + - + - +
NameNumberDescription
CONCAT 0
SUM 1
- - - - - + + + + +

easy_rec/python/protos/mmoe.proto

Top

 
-      
+
         

ExpertTower


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
expert_name string required
 
dnn DNN required
 
- - - + + +

MMoE


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
experts ExpertTower repeated
deprecated: original mmoe experts config 
expert_dnn DNN optional
mmoe expert dnn layer definition 
num_expert uint32 optional
number of mmoe experts Default: 0
task_towers TaskTower repeated
task tower 
l2_regularization float required
l2 regularization Default: 0.0001
- - - - - - - - + + + + + + + +

easy_rec/python/protos/multi_tower.proto

Top

 
-      
+
         

BSTTower


 
-        
+
           
-              
+
                 
-              
+
                 
-              
+
                 
-              
+
             
FieldTypeLabelDescription
input string required
 
seq_len uint32 required
 Default: 5
multi_head_size uint32 required
 Default: 4
- - - + + +

DINTower

Field | Type | Label | Description
input | string | required |
dnn | DNN | required |

MultiTower

Field | Type | Label | Description
towers | Tower | repeated |
final_dnn | DNN | required |
l2_regularization | float | required | Default: 0.0001
din_towers | DINTower | repeated |
bst_towers | BSTTower | repeated |
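
A hedged MultiTower sketch; the enclosing `multi_tower` field name and the DNN `hidden_units` field are assumptions, while towers, final_dnn and l2_regularization are from the table above and Tower is documented in tower.proto below:

  multi_tower {
    towers {
      input: "user"
      dnn { hidden_units: [256, 128] }
    }
    towers {
      input: "item"
      dnn { hidden_units: [256, 128] }
    }
    final_dnn { hidden_units: [128, 64] }
    l2_regularization: 0.0001
  }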

easy_rec/python/protos/optimizer.proto

AdagradOptimizer

Configuration message for the AdagradOptimizer.
See: https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer

Field | Type | Label | Description
learning_rate | LearningRate | optional |

AdamAsyncOptimizer

Only available on pai-tf, which has better performance than AdamOptimizer.

Field | Type | Label | Description
learning_rate | LearningRate | optional |
beta1 | float | optional | Default: 0.9
beta2 | float | optional | Default: 0.999

AdamAsyncWOptimizer

Field | Type | Label | Description
learning_rate | LearningRate | optional |
weight_decay | float | optional | Default: 1e-06
beta1 | float | optional | Default: 0.9
beta2 | float | optional | Default: 0.999

AdamOptimizer

Configuration message for the AdamOptimizer.
See: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer

Field | Type | Label | Description
learning_rate | LearningRate | optional |
beta1 | float | optional | Default: 0.9
beta2 | float | optional | Default: 0.999

AdamWOptimizer

Field | Type | Label | Description
learning_rate | LearningRate | optional |
weight_decay | float | optional | Default: 1e-06
beta1 | float | optional | Default: 0.9
beta2 | float | optional | Default: 0.999

ConstantLearningRate

Configuration message for a constant learning rate.

Field | Type | Label | Description
learning_rate | float | optional | Default: 0.002

CosineDecayLearningRate

Configuration message for a cosine decaying learning rate as defined in utils/learning_schedules.py.

Field | Type | Label | Description
learning_rate_base | float | optional | Default: 0.002
total_steps | uint32 | optional | Default: 4000000
warmup_learning_rate | float | optional | Default: 0.0002
warmup_steps | uint32 | optional | Default: 10000
hold_base_rate_steps | uint32 | optional | Default: 0

ExponentialDecayLearningRate

Configuration message for an exponentially decaying learning rate.
See https://www.tensorflow.org/versions/master/api_docs/python/train/decaying_the_learning_rate#exponential_decay

Field | Type | Label | Description
initial_learning_rate | float | optional | Default: 0.002
decay_steps | uint32 | optional | Default: 4000000
decay_factor | float | optional | Default: 0.95
staircase | bool | optional | Default: true
burnin_learning_rate | float | optional | Default: 0
burnin_steps | uint32 | optional | Default: 0
min_learning_rate | float | optional | Default: 0

FtrlOptimizer

Field | Type | Label | Description
learning_rate | LearningRate | optional | optional float learning_rate = 1 [default=1e-4];
learning_rate_power | float | optional | Default: -0.5
initial_accumulator_value | float | optional | Default: 0.1
l1_reg | float | optional | Default: 0
l2_reg | float | optional | Default: 0
l2_shrinkage_reg | float | optional | Default: 0

LearningRate

Configuration message for optimizer learning rate.

Field | Type | Label | Description
constant_learning_rate | ConstantLearningRate | optional |
exponential_decay_learning_rate | ExponentialDecayLearningRate | optional |
manual_step_learning_rate | ManualStepLearningRate | optional |
cosine_decay_learning_rate | CosineDecayLearningRate | optional |
poly_decay_learning_rate | PolyDecayLearningRate | optional |
transformer_learning_rate | TransformerLearningRate | optional |

ManualStepLearningRate

Configuration message for a manually defined learning rate schedule.

Field | Type | Label | Description
initial_learning_rate | float | optional | Default: 0.002
schedule | ManualStepLearningRate.LearningRateSchedule | repeated |
warmup | bool | | Whether to linearly interpolate learning rates for steps in [0, schedule[0].step]. Default: false

ManualStepLearningRate.LearningRateSchedule

Field | Type | Label | Description
step | uint32 | optional |
learning_rate | float | optional | Default: 0.002

MomentumOptimizer

Configuration message for the MomentumOptimizer.
See: https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer

Field | Type | Label | Description
learning_rate | LearningRate | optional |
momentum_optimizer_value | float | optional | Default: 0.9

MomentumWOptimizer

Field | Type | Label | Description
learning_rate | LearningRate | optional |
weight_decay | float | optional | Default: 1e-06
momentum_optimizer_value | float | optional | Default: 0.9

Optimizer

Top level optimizer message.

Field | Type | Label | Description
rms_prop_optimizer | RMSPropOptimizer | optional |
momentum_optimizer | MomentumOptimizer | optional |
adam_optimizer | AdamOptimizer | optional |
momentumw_optimizer | MomentumWOptimizer | optional |
adamw_optimizer | AdamWOptimizer | optional |
adam_async_optimizer | AdamAsyncOptimizer | optional |
adagrad_optimizer | AdagradOptimizer | optional |
ftrl_optimizer | FtrlOptimizer | optional |
adam_asyncw_optimizer | AdamAsyncWOptimizer | optional |
use_moving_average | bool | optional | Default: false
moving_average_decay | float | optional | Default: 0.9999
embedding_learning_rate_multiplier | float | optional |
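
A minimal optimizer sketch in protobuf text format, composed only of fields documented in this file; it is written as a train_config.optimizer_config entry (see TrainConfig below):

  optimizer_config {
    adam_optimizer {
      learning_rate {
        exponential_decay_learning_rate {
          initial_learning_rate: 0.001
          decay_steps: 10000
          decay_factor: 0.95
          min_learning_rate: 1e-07
        }
      }
      beta1: 0.9
      beta2: 0.999
    }
    use_moving_average: false
  }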

PolyDecayLearningRate

Configuration message for a polynomially decaying learning rate.
See https://www.tensorflow.org/api_docs/python/tf/train/polynomial_decay.

Field | Type | Label | Description
learning_rate_base | float | required |
total_steps | int64 | required |
power | float | required |
end_learning_rate | float | optional | Default: 0

RMSPropOptimizer

Configuration message for the RMSPropOptimizer.
See: https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer

Field | Type | Label | Description
learning_rate | LearningRate | optional |
momentum_optimizer_value | float | optional | Default: 0.9
decay | float | optional | Default: 0.9
epsilon | float | optional | Default: 1

TransformerLearningRate

Field | Type | Label | Description
learning_rate_base | float | required |
hidden_size | int32 | required |
warmup_steps | int32 | required |
step_scaling_rate | float | optional | Default: 1

easy_rec/python/protos/pdn.proto

easy_rec/python/protos/pipeline.proto

EasyRecConfig

Field | Type | Label | Description
train_input_path | string | optional |
kafka_train_input | KafkaServer | optional |
datahub_train_input | DatahubServer | optional |
hive_train_input | HiveConfig | optional |
binary_train_input | BinaryDataInput | optional |
eval_input_path | string | optional |
kafka_eval_input | KafkaServer | optional |
datahub_eval_input | DatahubServer | optional |
hive_eval_input | HiveConfig | optional |
binary_eval_input | BinaryDataInput | optional |
model_dir | string | required |
train_config | TrainConfig | optional | train config, including optimizer, weight decay, num_steps and so on
eval_config | EvalConfig | optional |
data_config | DatasetConfig | optional |
feature_configs | FeatureConfig | repeated | for compatibility
feature_config | FeatureConfigV2 | optional |
model_config | EasyRecModel | required | recommendation model config
export_config | ExportConfig | optional |
fg_json_path | string | optional |
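
A minimal pipeline config skeleton using only the EasyRecConfig fields above; paths are placeholders and sub-message contents are elided:

  train_input_path: "data/train.csv"
  eval_input_path: "data/eval.csv"
  model_dir: "experiments/demo_model"

  data_config {
    # DatasetConfig: input fields and data format
  }
  feature_config {
    # FeatureConfigV2: feature definitions
  }
  model_config {
    # EasyRecModel (required): the recommendation model, e.g. one of the
    # messages documented in this file (MultiTower, MMoE, MIND, ...)
  }
  train_config {
    # TrainConfig: optimizer_config, num_steps, checkpoint settings, ...
  }
  eval_config {
  }
  export_config {
  }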

easy_rec/python/protos/ple.proto

ExtractionNetwork

Field | Type | Label | Description
network_name | string | required |
expert_num_per_task | uint32 | required | number of experts per task
share_num | uint32 | | number of experts for share; for the last extraction_network, no need to configure this
task_expert_net | DNN | required | dnn network of experts per task
share_expert_net | DNN | | dnn network of experts for share; for the last extraction_network, no need to configure this

PLE

Field | Type | Label | Description
extraction_networks | ExtractionNetwork | repeated | extraction network
task_towers | TaskTower | repeated | task tower
l2_regularization | float | optional | l2 regularization. Default: 0.0001
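
A hedged PLE sketch; the enclosing `ple` field name and the DNN `hidden_units` field are assumptions, while the remaining fields come from ExtractionNetwork, PLE and TaskTower as documented in this file:

  ple {
    extraction_networks {
      network_name: "layer1"
      expert_num_per_task: 2
      share_num: 2
      task_expert_net { hidden_units: [128, 64] }
      share_expert_net { hidden_units: [128, 64] }
    }
    extraction_networks {
      # last extraction network: share_num / share_expert_net can be omitted
      network_name: "layer2"
      expert_num_per_task: 2
      task_expert_net { hidden_units: [64, 32] }
    }
    task_towers { tower_name: "ctr" label_name: "clk" dnn { hidden_units: [64, 32] } }
    task_towers { tower_name: "cvr" label_name: "buy" dnn { hidden_units: [64, 32] } }
    l2_regularization: 0.0001
  }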

easy_rec/python/protos/rocket_launching.proto

RocketLaunching

Field | Type | Label | Description
share_dnn | DNN | required |
booster_dnn | DNN | required |
light_dnn | DNN | required |
l2_regularization | float | optional | Default: 0.0001
feature_based_distillation | bool | optional | Default: false
feature_distillation_function | Similarity | optional | COSINE = 0; EUCLID = 1; Default: COSINE
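
A hedged RocketLaunching sketch; the enclosing `rocket_launching` field name and the DNN `hidden_units` field are assumptions, the remaining fields are from the table above:

  rocket_launching {
    share_dnn { hidden_units: [256, 128] }
    booster_dnn { hidden_units: [128, 64] }
    light_dnn { hidden_units: [64, 32] }
    feature_based_distillation: false
    l2_regularization: 0.0001
  }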

easy_rec/python/protos/simi.proto

Similarity

Name | Number | Description
COSINE | 0 |
INNER_PRODUCT | 1 |
EUCLID | 2 |

easy_rec/python/protos/simple_multi_task.proto

SimpleMultiTask

Field | Type | Label | Description
task_towers | TaskTower | repeated |
l2_regularization | float | required | Default: 0.0001

easy_rec/python/protos/tower.proto

BayesTaskTower

Field | Type | Label | Description
tower_name | string | required | task name for the task tower
label_name | string | optional | label for the task, default is label_fields by order
metrics_set | EvalMetrics | repeated | metrics for the task
loss_type | LossType | optional | loss for the task. Default: CLASSIFICATION
num_class | uint32 | optional | num_class for multi-class classification loss. Default: 1
dnn | DNN | optional | task specific dnn
relation_tower_names | string | repeated | related tower names
relation_dnn | DNN | optional | relation dnn
weight | float | optional | training loss weights. Default: 1
task_space_indicator_label | string | optional | label name for indicating the sample space for the task tower
in_task_space_weight | float | optional | the loss weight for sample in the task space. Default: 1
out_task_space_weight | float | | prediction weights: optional float prediction_weight = 14 [default = 1.0]. Default: 1

TaskTower

Field | Type | Label | Description
tower_name | string | required | task name for the task tower
label_name | string | optional | label for the task, default is label_fields by order
metrics_set | EvalMetrics | repeated | metrics for the task
loss_type | LossType | optional | loss for the task. Default: CLASSIFICATION
num_class | uint32 | optional | num_class for multi-class classification loss. Default: 1
dnn | DNN | optional | task specific dnn
weight | float | optional | training loss weights. Default: 1
task_space_indicator_label | string | optional | label name for indicating the sample space for the task tower
in_task_space_weight | float | optional | the loss weight for sample in the task space. Default: 1
out_task_space_weight | float | optional | the loss weight for sample out of the task space. Default: 1
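
A hedged task tower sketch showing the sample-space weighting fields above; the `metrics_set { auc {} }` form assumes EvalMetrics exposes an `auc` metric, and the DNN `hidden_units` field is likewise assumed, neither being documented in this section:

  task_towers {
    tower_name: "cvr"
    label_name: "buy"
    loss_type: CLASSIFICATION
    dnn { hidden_units: [128, 64] }
    metrics_set { auc {} }          # assumed EvalMetrics field
    # weight samples by whether they fall into the task's sample space (clicked)
    task_space_indicator_label: "clk"
    in_task_space_weight: 1.0
    out_task_space_weight: 0.1
  }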

Tower

Field | Type | Label | Description
input | string | required |
dnn | DNN | required |

easy_rec/python/protos/train.proto

TrainConfig

Message for configuring EasyRecModel training jobs (train.py).
Next id: 25

Field | Type | Label | Description
optimizer_config | Optimizer | repeated | optimizer options
gradient_clipping_by_norm | float | optional | If greater than 0, clips gradients by this value. Default: 0
num_steps | uint32 | | Number of steps to train the models; if 0, will train the model indefinitely. Default: 0
fine_tune_checkpoint | string | optional | Checkpoint to restore variables from.
fine_tune_ckpt_var_map | string | optional |
sync_replicas | bool | | Whether to synchronize replicas during training. In case so, build a SyncReplicasOptimizer. Default: true
sparse_accumulator_type | string | | raw, hash, multi_map, list, parallel; in general, multi_map runs faster than the other options. Default: multi_map
startup_delay_steps | float | | Number of training steps between replica startup. This flag must be set to 0 if sync_replicas is set to true. Default: 15
save_checkpoints_steps | uint32 | optional | Step interval for saving checkpoint. Default: 1000
save_checkpoints_secs | uint32 | optional | Seconds interval for saving checkpoint
keep_checkpoint_max | uint32 | optional | Max checkpoints to keep. Default: 10
save_summary_steps | uint32 | optional | Save summaries every this many steps. Default: 1000
log_step_count_steps | uint32 | optional | The frequency at which global step/sec and the loss will be logged during training. Default: 10
is_profiling | bool | optional | profiling or not. Default: false
force_restore_shape_compatible | bool | optional | if variable shape is incompatible, clip or pad variables in checkpoint. Default: false
train_distribute | DistributionStrategy | | mirrored: MirroredStrategy, single machine and multiple devices; collective: CollectiveAllReduceStrategy, multiple machines and multiple devices. Default: NoStrategy
num_gpus_per_worker | int32 | optional | Number of gpus per machine. Default: 1
summary_model_vars | bool | optional | summary model variables or not. Default: false
protocol | string | | grpc++: https://help.aliyun.com/document_detail/173157.html?spm=5176.10695662.1996646101.searchclickresult.3ebf450evuaPT3 star_server: https://help.aliyun.com/document_detail/173154.html?spm=a2c4g.11186623.6.627.39ad7e3342KOX4
inter_op_parallelism_threads | int32 | optional | inter_op_parallelism_threads. Default: 0
intra_op_parallelism_threads | int32 | optional | intra_op_parallelism_threads. Default: 0
tensor_fuse | bool | optional | tensor fusion on PAI-TF. Default: false
write_graph | bool | optional | write graph into graph.pbtxt and summary or not. Default: true
freeze_gradient | string | repeated | match variable patterns to freeze
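
A minimal train_config sketch composed only of fields documented above (the optimizer and learning rate messages are documented earlier under optimizer.proto):

  train_config {
    optimizer_config {
      adam_optimizer {
        learning_rate {
          constant_learning_rate { learning_rate: 0.001 }
        }
      }
    }
    num_steps: 100000
    sync_replicas: true
    save_checkpoints_steps: 1000
    keep_checkpoint_max: 10
    log_step_count_steps: 100
    train_distribute: NoStrategy
    num_gpus_per_worker: 1
  }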

DistributionStrategy

Name | Number | Description
NoStrategy | 0 | use old SyncReplicasOptimizer for ParameterServer training
PSStrategy | 1 | PSStrategy with multiple gpus on one node could not work on pai-tf; could only work on TF >= 1.15
MirroredStrategy | 2 | could only work on PaiTF or TF >= 1.15; single worker multiple gpu mode
CollectiveAllReduceStrategy | 3 | Deprecated
ExascaleStrategy | 4 | currently not working well
MultiWorkerMirroredStrategy | 5 | multi worker multi gpu mode; see tf.distribute.experimental.MultiWorkerMirroredStrategy

easy_rec/python/protos/variational_dropout.proto

VariationalDropoutLayer

Field | Type | Label | Description
regularization_lambda | float | optional | regularization coefficient lambda. Default: 0.01
embedding_wise_variational_dropout | bool | optional | variational_dropout dimension. Default: false
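
A hedged sketch of a variational dropout block; the enclosing `variational_dropout` field name inside model_config is an assumption, the two fields are from the table above:

  variational_dropout {
    regularization_lambda: 0.01
    embedding_wise_variational_dropout: false
  }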

easy_rec/python/protos/wide_and_deep.proto

WideAndDeep

Field | Type | Label | Description
wide_output_dim | uint32 | required | Default: 1
dnn | DNN | required |
final_dnn | DNN | | if set, the output of dnn and wide part are concatenated and passed to the final_dnn; otherwise, they are summed
l2_regularization | float | optional | Default: 0.0001
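
A hedged WideAndDeep sketch; the enclosing `wide_and_deep` field name and the DNN `hidden_units` field are assumptions, the remaining fields are from the table above:

  wide_and_deep {
    wide_output_dim: 16
    dnn { hidden_units: [256, 128, 64] }
    l2_regularization: 0.0001
  }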

Scalar Value Types

.proto Type | Notes | C++ Type | Java Type | Python Type
double | | double | double | float
float | | float | float | float
int32 | Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint32 instead. | int32 | int | int
int64 | Uses variable-length encoding. Inefficient for encoding negative numbers – if your field is likely to have negative values, use sint64 instead. | int64 | long | int/long
uint32 | Uses variable-length encoding. | uint32 | int | int/long
uint64 | Uses variable-length encoding. | uint64 | long | int/long
sint32 | Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int32s. | int32 | int | int
sint64 | Uses variable-length encoding. Signed int value. These more efficiently encode negative numbers than regular int64s. | int64 | long | int/long
fixed32 | Always four bytes. More efficient than uint32 if values are often greater than 2^28. | uint32 | int | int
fixed64 | Always eight bytes. More efficient than uint64 if values are often greater than 2^56. | uint64 | long | int/long
sfixed32 | Always four bytes. | int32 | int | int
sfixed64 | Always eight bytes. | int64 | long | int/long
bool | | bool | boolean | boolean
string | A string must always contain UTF-8 encoded or 7-bit ASCII text. | string | String | str/unicode
bytes | May contain any arbitrary sequence of bytes. | string | ByteString | str
- diff --git a/easy_rec/python/builders/loss_builder.py b/easy_rec/python/builders/loss_builder.py index bee92b203..8ea7393f3 100644 --- a/easy_rec/python/builders/loss_builder.py +++ b/easy_rec/python/builders/loss_builder.py @@ -20,7 +20,7 @@ def build(loss_type, label, pred, loss_weight=1.0, num_class=1, **kwargs): label, logits=pred, weights=loss_weight, **kwargs) else: assert label.dtype in [tf.int32, tf.int64], \ - "label.dtype must in [tf.int32, tf.int64] when use sparse_softmax_cross_entropy." + 'label.dtype must in [tf.int32, tf.int64] when use sparse_softmax_cross_entropy.' return tf.losses.sparse_softmax_cross_entropy( labels=label, logits=pred, weights=loss_weight, **kwargs) elif loss_type == LossType.CROSS_ENTROPY_LOSS: diff --git a/easy_rec/python/core/distribute_metrics.py b/easy_rec/python/core/distribute_metrics.py index 93e742364..055873bbd 100644 --- a/easy_rec/python/core/distribute_metrics.py +++ b/easy_rec/python/core/distribute_metrics.py @@ -5,6 +5,7 @@ import numpy as np import tensorflow as tf from sklearn import metrics as sklearn_metrics + from easy_rec.python.utils import pai_util from easy_rec.python.utils.shape_utils import get_shape_list @@ -16,6 +17,7 @@ if tf.__version__ >= '2.0': tf = tf.compat.v1 + def max_f1(label, predictions): """Calculate the largest F1 metric under different thresholds. @@ -134,11 +136,12 @@ def session_auc(labels, predictions, session_ids, reduction='mean'): """ return _separated_auc_impl(labels, predictions, session_ids, reduction) + def distribute_metric_learning_recall_at_k(k, - embeddings, - labels, - session_ids=None, - embed_normed=False): + embeddings, + labels, + session_ids=None, + embed_normed=False): """Computes the recall_at_k metric for metric learning. Args: @@ -181,6 +184,7 @@ def distribute_metric_learning_recall_at_k(k, else: raise ValueError('k should be a `int` or a list/tuple/set of int.') + def _get_matrix_mask_indices(matrix, num_rows=None): if num_rows is None: num_rows = get_shape_list(matrix)[0] @@ -202,11 +206,12 @@ def _get_matrix_mask_indices(matrix, num_rows=None): result = tf.where(result >= 0, result, max_index_per_row) return result + def distribute_metric_learning_average_precision_at_k(k, - embeddings, - labels, - session_ids=None, - embed_normed=False): + embeddings, + labels, + session_ids=None, + embed_normed=False): # make sure embedding should be l2-normalized if not embed_normed: embeddings = tf.nn.l2_normalize(embeddings, axis=1) @@ -221,7 +226,8 @@ def distribute_metric_learning_average_precision_at_k(k, mask = tf.logical_and(sessions_equal, mask) label_indices = _get_matrix_mask_indices(mask) if isinstance(k, int): - return distribute_metrics_tf.average_precision_at_k(label_indices, sim_mat, k) + return distribute_metrics_tf.average_precision_at_k(label_indices, sim_mat, + k) if any((isinstance(k, list), isinstance(k, tuple), isinstance(k, set))): metrics = {} for kk in k: @@ -231,4 +237,4 @@ def distribute_metric_learning_average_precision_at_k(k, label_indices, sim_mat, kk) return metrics else: - raise ValueError('k should be a `int` or a list/tuple/set of int.') \ No newline at end of file + raise ValueError('k should be a `int` or a list/tuple/set of int.') diff --git a/easy_rec/python/core/sampler.py b/easy_rec/python/core/sampler.py index b9a0338e9..64b1e1eb4 100644 --- a/easy_rec/python/core/sampler.py +++ b/easy_rec/python/core/sampler.py @@ -14,6 +14,7 @@ import tensorflow as tf from easy_rec.python.protos.dataset_pb2 import DatasetConfig +from easy_rec.python.utils.tf_utils 
import get_tf_type try: import graphlearn as gl @@ -52,19 +53,6 @@ def _get_np_type(field_type): return type_map[field_type] -def _get_tf_type(field_type): - type_map = { - DatasetConfig.INT32: tf.int32, - DatasetConfig.INT64: tf.int64, - DatasetConfig.STRING: tf.string, - DatasetConfig.BOOL: tf.bool, - DatasetConfig.FLOAT: tf.float32, - DatasetConfig.DOUBLE: tf.double - } - assert field_type in type_map, 'invalid type: %s' % field_type - return type_map[field_type] - - class BaseSampler(object): _instance_lock = threading.Lock() @@ -134,7 +122,7 @@ def _build_field_types(self, fields): self._attr_types.append(field.input_type) self._attr_gl_types.append(_get_gl_type(field.input_type)) self._attr_np_types.append(_get_np_type(field.input_type)) - self._attr_tf_types.append(_get_tf_type(field.input_type)) + self._attr_tf_types.append(get_tf_type(field.input_type)) @classmethod def instance(cls, *args, **kwargs): @@ -150,7 +138,7 @@ def __del__(self): def _parse_nodes(self, nodes): if self._log_first_n > 0: logging.info('num_example=%d num_eval_example=%d node_num=%d' % - (self._num_sample, self._num_eval_sample, len(nodes.ids))) + (self._num_sample, self._num_eval_sample, len(nodes.ids))) self._log_first_n -= 1 features = [] int_idx = 0 diff --git a/easy_rec/python/eval.py b/easy_rec/python/eval.py index 7408f23e7..394a767a2 100644 --- a/easy_rec/python/eval.py +++ b/easy_rec/python/eval.py @@ -8,7 +8,9 @@ import tensorflow as tf from tensorflow.python.lib.io import file_io -from easy_rec.python.main import evaluate, distribute_evaluate +from easy_rec.python.main import distribute_evaluate +from easy_rec.python.main import evaluate + from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_distribute_eval_worker_num_on_ds # NOQA if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -38,7 +40,7 @@ def main(argv): if FLAGS.odps_config: os.environ['ODPS_CONFIG_FILE_PATH'] = FLAGS.odps_config - + if FLAGS.is_on_ds and FLAGS.distribute_eval: set_tf_config_and_get_distribute_eval_worker_num_on_ds() @@ -53,11 +55,12 @@ def main(argv): pipeline_config_path = FLAGS.pipeline_config_path if FLAGS.distribute_eval: - eval_result = distribute_evaluate(pipeline_config_path, FLAGS.checkpoint_path, - FLAGS.eval_input_path) + eval_result = distribute_evaluate(pipeline_config_path, + FLAGS.checkpoint_path, + FLAGS.eval_input_path) else: eval_result = evaluate(pipeline_config_path, FLAGS.checkpoint_path, - FLAGS.eval_input_path) + FLAGS.eval_input_path) if eval_result is not None: # when distribute evaluate, only master has eval_result. 
for key in sorted(eval_result): @@ -66,7 +69,7 @@ def main(argv): continue logging.info('%s: %s' % (key, str(eval_result[key]))) else: - logging.info("Eval result in master worker.") + logging.info('Eval result in master worker.') if __name__ == '__main__': diff --git a/easy_rec/python/inference/predictor.py b/easy_rec/python/inference/predictor.py index 74d4868e6..b4e43c5e6 100644 --- a/easy_rec/python/inference/predictor.py +++ b/easy_rec/python/inference/predictor.py @@ -5,6 +5,7 @@ from __future__ import print_function import abc +import json import logging import math import os @@ -18,10 +19,14 @@ from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import signature_constants -from easy_rec.python.utils import pai_util +from easy_rec.python.protos.dataset_pb2 import DatasetConfig +from easy_rec.python.utils.check_utils import check_split from easy_rec.python.utils.config_util import get_configs_from_pipeline_file +from easy_rec.python.utils.config_util import get_input_name_from_fg_json +from easy_rec.python.utils.hive_utils import HiveUtils from easy_rec.python.utils.input_utils import get_type_defaults from easy_rec.python.utils.load_class import get_register_class_meta +from easy_rec.python.utils.tf_utils import get_tf_type if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -153,7 +158,6 @@ def _get_input_fields_from_pipeline_config(self, model_path): '%s not exists, default values maybe inconsistent with the values used in training.' % pipeline_path) return {} - pipeline_config = get_configs_from_pipeline_file(pipeline_path) input_fields = pipeline_config.data_config.input_fields input_fields_info = { @@ -312,7 +316,7 @@ def predict(self, input_data_dict, output_names=None): class Predictor(PredictorInterface): - def __init__(self, model_path, profiling_file=None): + def __init__(self, model_path, profiling_file=None, fg_json_path=None): """Initialize a `Predictor`. Args: @@ -320,6 +324,7 @@ def __init__(self, model_path, profiling_file=None): profiling_file: profiling result file, default None. 
if not None, predict function will use Timeline to profiling prediction time, and the result json will be saved to profiling_file + fg_json_path: fg.json file """ self._predictor_impl = PredictorImpl(model_path, profiling_file) self._inputs_map = self._predictor_impl._inputs_map @@ -330,6 +335,9 @@ def __init__(self, model_path, profiling_file=None): self._is_multi_placeholder = self._predictor_impl._is_multi_placeholder self._input_fields = self._predictor_impl._input_fields_list + fg_json = self._get_fg_json(fg_json_path, model_path) + self._all_input_names = get_input_name_from_fg_json(fg_json) + logging.info('all_input_names: %s' % self._all_input_names) @property def input_names(self): @@ -349,27 +357,74 @@ def output_names(self): """ return list(self._outputs_map.keys()) - def predict_impl(self, - input_table, - output_table, - all_cols='', - all_col_types='', - selected_cols='', - reserved_cols='', - output_cols=None, - batch_size=1024, - slice_id=0, - slice_num=1, - input_sep=',', - output_sep=chr(1)): + def _get_defaults(self, col_name, col_type='string'): + if col_name in self._input_fields_info: + col_type, default_val = self._input_fields_info[col_name] + default_val = get_type_defaults(col_type, default_val) + logging.info('col_name: %s, default_val: %s' % (col_name, default_val)) + else: + defaults = {'string': '', 'double': 0.0, 'bigint': 0} + assert col_type in defaults, 'invalid col_type: %s, col_type: %s' % ( + col_name, col_type) + default_val = defaults[col_type] + logging.info( + 'col_name: %s, default_val: %s.[not defined in saved_model_dir/assets/pipeline.config]' + % (col_name, default_val)) + return default_val + + def _parse_line(self, line): + pass + + def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num, + slice_id): + pass + + def _get_writer(self, output_path, slice_id): + pass + + def _get_reserved_cols(self, reserved_cols): + pass + + @property + def out_of_range_exception(self): + return None + + def _write_line(self, table_writer, outputs): + pass + + def _get_fg_json(self, fg_json_path, model_path): + if fg_json_path and gfile.Exists(fg_json_path): + logging.info('load fg_json_path: ', fg_json_path) + with tf.gfile.GFile(fg_json_path, 'r') as fin: + fg_json = json.loads(fin.read()) + else: + fg_json_path = os.path.join(model_path, 'assets/fg.json') + if gfile.Exists(fg_json_path): + logging.info('load fg_json_path: ', fg_json_path) + with tf.gfile.GFile(fg_json_path, 'r') as fin: + fg_json = json.loads(fin.read()) + else: + fg_json = {} + return fg_json + + def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs): + pass + + def predict_impl( + self, + input_path, + output_path, + reserved_cols='', + output_cols=None, + batch_size=1024, + slice_id=0, + slice_num=1, + ): """Predict table input with loaded model. 
Args: - input_table: table/file_path to read - output_table: table/file_path to write - all_cols: union of columns - all_col_types: data types of the columns - selected_cols: included column names, comma separated, such as "a,b,c" + input_path: table/file_path to read + output_path: table/file_path to write reserved_cols: columns to be copy to output_table, comma separated, such as "a,b" output_cols: output columns, comma separated, such as "y float, embedding string", the output names[y, embedding] must be in saved_model output_names @@ -377,46 +432,10 @@ def predict_impl(self, slice_id: when multiple workers write the same table, each worker should be assigned different slice_id, which is usually slice_id slice_num: table slice number - input_sep: separator of input file. - output_sep: separator of predict result file. """ - if pai_util.is_on_pai(): - self.predict_table( - input_table, - output_table, - all_cols=all_cols, - all_col_types=all_col_types, - selected_cols=selected_cols, - reserved_cols=reserved_cols, - output_cols=output_cols, - batch_size=batch_size, - slice_id=slice_id, - slice_num=slice_num) - else: - self.predict_csv( - input_table, - output_table, - reserved_cols=reserved_cols, - output_cols=output_cols, - batch_size=batch_size, - slice_id=slice_id, - slice_num=slice_num, - input_sep=input_sep, - output_sep=output_sep) - - def predict_csv(self, input_path, output_path, reserved_cols, output_cols, - batch_size, slice_id, slice_num, input_sep, output_sep): - record_defaults = [ - get_type_defaults(*self._input_fields_info[col_name]) for col_name in self._input_fields - ] - - if reserved_cols == 'ALL_COLUMNS': - reserved_cols = self._input_fields - else: - reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != ''] if output_cols is None or output_cols == 'ALL_COLUMNS': - output_cols = sorted(self._predictor_impl.output_names) - logging.info('predict output cols: %s' % output_cols) + self._output_cols = sorted(self._predictor_impl.output_names) + logging.info('predict output cols: %s' % self._output_cols) else: # specified as score float,embedding string tmp_cols = [] @@ -425,179 +444,19 @@ def predict_csv(self, input_path, output_path, reserved_cols, output_cols, continue tmp_keys = x.split(' ') tmp_cols.append(tmp_keys[0].strip()) - output_cols = tmp_cols + self._output_cols = tmp_cols with tf.Graph().as_default(), tf.Session() as sess: num_parallel_calls = 8 - file_paths = [] - for x in input_path.split(','): - file_paths.extend(gfile.Glob(x)) - assert len(file_paths) > 0, 'match no files with %s' % input_path - - dataset = tf.data.Dataset.from_tensor_slices(file_paths) - parallel_num = min(num_parallel_calls, len(file_paths)) - dataset = dataset.interleave( - tf.data.TextLineDataset, - cycle_length=parallel_num, - num_parallel_calls=parallel_num) - dataset = dataset.shard(slice_num, slice_id) - logging.info('batch_size = %d' % batch_size) - dataset = dataset.batch(batch_size) - dataset = dataset.prefetch(buffer_size=64) - - def _parse_csv(line): - - def _check_data(line): - sep = input_sep - if type(sep) != type(str): - sep = sep.encode('utf-8') - field_num = len(line[0].split(sep)) - assert field_num == len(record_defaults), 'sep[%s] maybe invalid: field_num=%d, required_num=%d' \ - % (sep, field_num, len(record_defaults)) - return True - - check_op = tf.py_func(_check_data, [line], Tout=tf.bool) - with tf.control_dependencies([check_op]): - fields = tf.decode_csv( - line, - field_delim=',', - record_defaults=record_defaults, - name='decode_csv') - - 
inputs = {self._input_fields[x]: fields[x] for x in range(len(fields))} - return inputs - - dataset = dataset.map(_parse_csv, num_parallel_calls=num_parallel_calls) + dataset = self._get_dataset(input_path, num_parallel_calls, batch_size, + slice_num, slice_id) + dataset = dataset.map( + self._parse_line, num_parallel_calls=num_parallel_calls) iterator = dataset.make_one_shot_iterator() all_dict = iterator.get_next() - - if not gfile.Exists(output_path): - gfile.MakeDirs(output_path) - res_path = os.path.join(output_path, 'slice_%d.csv' % slice_id) - table_writer = gfile.GFile(res_path, 'w') - - input_names = self._predictor_impl.input_names - progress = 0 - sum_t0, sum_t1, sum_t2 = 0, 0, 0 - pred_cnt = 0 - table_writer.write(output_sep.join(output_cols + reserved_cols) + '\n') - while True: - try: - ts0 = time.time() - all_vals = sess.run(all_dict) - - ts1 = time.time() - input_vals = {k: all_vals[k] for k in input_names} - outputs = self._predictor_impl.predict(input_vals, output_cols) - - for x in output_cols: - if outputs[x].dtype == np.object: - outputs[x] = [val.decode('utf-8') for val in outputs[x]] - for k in reserved_cols: - if all_vals[k].dtype == np.object: - all_vals[k] = [val.decode('utf-8') for val in all_vals[k]] - - ts2 = time.time() - reserve_vals = [outputs[x] for x in output_cols] + \ - [all_vals[k] for k in reserved_cols] - outputs = [x for x in zip(*reserve_vals)] - pred_cnt += len(outputs) - outputs = '\n'.join( - [output_sep.join([str(i) for i in output]) for output in outputs]) - table_writer.write(outputs + '\n') - - ts3 = time.time() - progress += 1 - sum_t0 += (ts1 - ts0) - sum_t1 += (ts2 - ts1) - sum_t2 += (ts3 - ts2) - except tf.errors.OutOfRangeError: - break - if progress % 100 == 0: - logging.info('progress: batch_num=%d sample_num=%d' % - (progress, progress * batch_size)) - logging.info('time_stats: read: %.2f predict: %.2f write: %.2f' % - (sum_t0, sum_t1, sum_t2)) - logging.info('Final_time_stats: read: %.2f predict: %.2f write: %.2f' % - (sum_t0, sum_t1, sum_t2)) - table_writer.close() - logging.info('Predict %s done.' % input_path) - logging.info('Predict size: %d.' 
% pred_cnt) - - def predict_table(self, - input_table, - output_table, - all_cols, - all_col_types, - selected_cols, - reserved_cols, - output_cols=None, - batch_size=1024, - slice_id=0, - slice_num=1): - - def _get_defaults(col_name, col_type): - if col_name in self._input_fields_info: - col_type, default_val = self._input_fields_info[col_name] - default_val = get_type_defaults(col_type, default_val) - logging.info('col_name: %s, default_val: %s' % (col_name, default_val)) - else: - defaults = {'string': '', 'double': 0.0, 'bigint': 0} - assert col_type in defaults, 'invalid col_type: %s, col_type: %s' % ( - col_name, col_type) - default_val = defaults[col_type] - logging.info( - 'col_name: %s, default_val: %s.[not defined in saved_model_dir/assets/pipeline.config]' - % (col_name, default_val)) - return default_val - - all_cols = [x.strip() for x in all_cols.split(',') if x != ''] - all_col_types = [x.strip() for x in all_col_types.split(',') if x != ''] - reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != ''] - - if output_cols is None: - output_cols = self._predictor_impl.output_names - else: - # specified as score float,embedding string - tmp_cols = [] - for x in output_cols.split(','): - if x.strip() == '': - continue - tmp_keys = x.split(' ') - tmp_cols.append(tmp_keys[0].strip()) - output_cols = tmp_cols - - record_defaults = [ - _get_defaults(col_name, col_type) - for col_name, col_type in zip(all_cols, all_col_types) - ] - with tf.Graph().as_default(), tf.Session() as sess: - num_parallel_calls = 8 - input_table = input_table.split(',') - dataset = tf.data.TableRecordDataset([input_table], - record_defaults=record_defaults, - slice_id=slice_id, - slice_count=slice_num, - selected_cols=','.join(all_cols)) - - logging.info('batch_size = %d' % batch_size) - dataset = dataset.batch(batch_size) - dataset = dataset.prefetch(buffer_size=64) - - def _parse_table(*fields): - fields = list(fields) - field_dict = {all_cols[i]: fields[i] for i in range(len(fields))} - return field_dict - - dataset = dataset.map(_parse_table, num_parallel_calls=num_parallel_calls) - iterator = dataset.make_one_shot_iterator() - all_dict = iterator.get_next() - - import common_io - table_writer = common_io.table.TableWriter( - output_table, slice_id=slice_id) - + self._reserved_cols = self._get_reserved_cols(reserved_cols) input_names = self._predictor_impl.input_names + table_writer = self._get_writer(output_path, slice_id) def _parse_value(all_vals): if self._is_multi_placeholder: @@ -605,11 +464,22 @@ def _parse_value(all_vals): feature_vals = all_vals[SINGLE_PLACEHOLDER_FEATURE_KEY] split_index = [] split_vals = {} - for i, k in enumerate(input_names): - split_index.append(k) - split_vals[k] = [] + fg_input_size = len(feature_vals[0].decode('utf-8').split('\002')) + if fg_input_size == len(input_names): + for i, k in enumerate(input_names): + split_index.append(k) + split_vals[k] = [] + else: + assert self._all_input_names, 'must set fg_json_path when use fg input' + assert fg_input_size == len(self._all_input_names), \ + 'The size of features in fg_json != the size of fg input. 
' \ + 'The size of features in fg_json is: %s; The size of fg input is: %s' % \ + (fg_input_size, len(self._all_input_names)) + for i, k in enumerate(self._all_input_names): + split_index.append(k) + split_vals[k] = [] for record in feature_vals: - split_records = record.split('\002') + split_records = record.decode('utf-8').split('\002') for i, r in enumerate(split_records): split_vals[split_index[i]].append(r) return {k: np.array(split_vals[k]) for k in input_names} @@ -625,25 +495,27 @@ def _parse_value(all_vals): ts1 = time.time() input_vals = _parse_value(all_vals) - # logging.info('input names = %s' % input_names) - # logging.info('input vals = %s' % input_vals) - outputs = self._predictor_impl.predict(input_vals, output_cols) + outputs = self._predictor_impl.predict(input_vals, self._output_cols) + for x in self._output_cols: + if outputs[x].dtype == np.object: + outputs[x] = [val.decode('utf-8') for val in outputs[x]] + for k in self._reserved_cols: + if all_vals[k].dtype == np.object: + all_vals[k] = [val.decode('utf-8') for val in all_vals[k]] ts2 = time.time() - reserve_vals = [all_vals[k] for k in reserved_cols - ] + [outputs[x] for x in output_cols] - indices = list(range(0, len(reserve_vals))) + reserve_vals = self._get_reserve_vals(self._reserved_cols, + self._output_cols, all_vals, + outputs) outputs = [x for x in zip(*reserve_vals)] + self._write_line(table_writer, outputs) - table_writer.write(outputs, indices, allow_type_cast=False) ts3 = time.time() progress += 1 sum_t0 += (ts1 - ts0) sum_t1 += (ts2 - ts1) sum_t2 += (ts3 - ts2) - except tf.python_io.OutOfRangeException: - break - except tf.errors.OutOfRangeError: + except self.out_of_range_exception: break if progress % 100 == 0: logging.info('progress: batch_num=%d sample_num=%d' % @@ -653,7 +525,7 @@ def _parse_value(all_vals): logging.info('Final_time_stats: read: %.2f predict: %.2f write: %.2f' % (sum_t0, sum_t1, sum_t2)) table_writer.close() - logging.info('Predict %s done.' % input_table) + logging.info('Predict %s done.' % input_path) def predict(self, input_data_dict_list, output_names=None, batch_size=1): """Predict input data with loaded model. 
@@ -710,3 +582,310 @@ def batch(self, data_list): for key in batch_input: batch_input[key] = np.array(batch_input[key]) return batch_input + + +class CSVPredictor(Predictor): + + def __init__(self, + model_path, + data_config, + fg_json_path=None, + profiling_file=None, + selected_cols='', + input_sep=',', + output_sep=chr(1)): + super(CSVPredictor, self).__init__(model_path, profiling_file, fg_json_path) + self._input_sep = input_sep + self._output_sep = output_sep + input_type = DatasetConfig.InputType.Name(data_config.input_type).lower() + + if 'rtp' in input_type: + self._is_rtp = True + else: + self._is_rtp = False + if selected_cols: + self._selected_cols = [int(x) for x in selected_cols.split(',')] + else: + self._selected_cols = None + + def _get_reserved_cols(self, reserved_cols): + if reserved_cols == 'ALL_COLUMNS': + if self._is_rtp: + idx = 0 + reserved_cols = [] + for x in range(len(self._record_defaults) - 1): + if not self._selected_cols or x in self._selected_cols[:-1]: + reserved_cols.append(self._input_fields[idx]) + idx += 1 + else: + reserved_cols.append('no_used_%d' % x) + reserved_cols.append(SINGLE_PLACEHOLDER_FEATURE_KEY) + else: + reserved_cols = self._input_fields + else: + reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != ''] + return reserved_cols + + def _parse_line(self, line): + check_list = [ + tf.py_func( + check_split, [line, self._input_sep, + len(self._record_defaults)], + Tout=tf.bool) + ] + with tf.control_dependencies(check_list): + fields = tf.decode_csv( + line, + field_delim=self._input_sep, + record_defaults=self._record_defaults, + name='decode_csv') + if self._is_rtp: + inputs = {} + idx = 0 + for x in range(len(self._record_defaults) - 1): + if not self._selected_cols or x in self._selected_cols[:-1]: + inputs[self._input_fields[idx]] = fields[x] + idx += 1 + else: + inputs['no_used_%d' % x] = fields[x] + inputs[SINGLE_PLACEHOLDER_FEATURE_KEY] = fields[-1] + else: + inputs = {self._input_fields[x]: fields[x] for x in range(len(fields))} + return inputs + + def _get_num_cols(self, file_paths): + # try to figure out number of fields from one file + num_cols = -1 + with tf.gfile.GFile(file_paths[0], 'r') as fin: + num_lines = 0 + for line_str in fin: + line_tok = line_str.strip().split(self._input_sep) + if num_cols != -1: + assert num_cols == len(line_tok), \ + 'num selected cols is %d, not equal to %d, current line is: %s, please check input_sep and data.' 
% \ + (num_cols, len(line_tok), line_str) + num_cols = len(line_tok) + num_lines += 1 + if num_lines > 10: + break + logging.info('num selected cols = %d' % num_cols) + return num_cols + + def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num, + slice_id): + file_paths = [] + for x in input_path.split(','): + file_paths.extend(gfile.Glob(x)) + assert len(file_paths) > 0, 'match no files with %s' % input_path + + if self._is_rtp: + num_cols = self._get_num_cols(file_paths) + self._record_defaults = ['' for _ in range(num_cols)] + if not self._selected_cols: + self._selected_cols = list(range(num_cols)) + for col_idx in self._selected_cols[:-1]: + col_name = self._input_fields[col_idx] + default_val = self._get_defaults(col_name) + self._record_defaults[col_idx] = default_val + else: + self._record_defaults = [ + self._get_defaults(col_name) for col_name in self._input_fields + ] + + dataset = tf.data.Dataset.from_tensor_slices(file_paths) + parallel_num = min(num_parallel_calls, len(file_paths)) + dataset = dataset.interleave( + tf.data.TextLineDataset, + cycle_length=parallel_num, + num_parallel_calls=parallel_num) + dataset = dataset.shard(slice_num, slice_id) + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(buffer_size=64) + return dataset + + def _get_writer(self, output_path, slice_id): + if not gfile.Exists(output_path): + gfile.MakeDirs(output_path) + res_path = os.path.join(output_path, 'slice_%d.csv' % slice_id) + table_writer = gfile.GFile(res_path, 'w') + table_writer.write( + self._output_sep.join(self._output_cols + self._reserved_cols) + '\n') + return table_writer + + def _write_line(self, table_writer, outputs): + outputs = '\n'.join( + [self._output_sep.join([str(i) for i in output]) for output in outputs]) + table_writer.write(outputs + '\n') + + def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs): + reserve_vals = [outputs[x] for x in output_cols] + \ + [all_vals[k] for k in reserved_cols] + return reserve_vals + + @property + def out_of_range_exception(self): + return (tf.errors.OutOfRangeError) + + +class ODPSPredictor(Predictor): + + def __init__(self, + model_path, + fg_json_path=None, + profiling_file=None, + all_cols='', + all_col_types=''): + super(ODPSPredictor, self).__init__(model_path, profiling_file, + fg_json_path) + self._all_cols = [x.strip() for x in all_cols.split(',') if x != ''] + self._all_col_types = [ + x.strip() for x in all_col_types.split(',') if x != '' + ] + self._record_defaults = [ + self._get_defaults(col_name, col_type) + for col_name, col_type in zip(self._all_cols, self._all_col_types) + ] + + def _get_reserved_cols(self, reserved_cols): + reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != ''] + return reserved_cols + + def _parse_line(self, *fields): + fields = list(fields) + field_dict = {self._all_cols[i]: fields[i] for i in range(len(fields))} + return field_dict + + def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num, + slice_id): + input_list = input_path.split(',') + dataset = tf.data.TableRecordDataset( + input_list, + record_defaults=self._record_defaults, + slice_id=slice_id, + slice_count=slice_num, + selected_cols=','.join(self._all_cols)) + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(buffer_size=64) + return dataset + + def _get_writer(self, output_path, slice_id): + import common_io + table_writer = common_io.table.TableWriter(output_path, slice_id=slice_id) + return table_writer + + def 
_write_line(self, table_writer, outputs): + assert len(outputs) > 0 + indices = list(range(0, len(outputs[0]))) + table_writer.write(outputs, indices, allow_type_cast=False) + + @property + def out_of_range_exception(self): + return (tf.python_io.OutOfRangeException, tf.errors.OutOfRangeError) + + def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs): + reserve_vals = [all_vals[k] for k in reserved_cols] + \ + [outputs[x] for x in output_cols] + return reserve_vals + + +class HivePredictor(Predictor): + + def __init__(self, + model_path, + data_config, + hive_config, + fg_json_path=None, + profiling_file=None, + selected_cols='', + output_sep=chr(1)): + super(HivePredictor, self).__init__(model_path, profiling_file, + fg_json_path) + + self._data_config = data_config + self._hive_config = hive_config + self._eval_batch_size = data_config.eval_batch_size + self._fetch_size = self._hive_config.fetch_size + self._output_sep = output_sep + self._record_defaults = [ + self._get_defaults(col_name) for col_name in self._input_fields + ] + input_type = DatasetConfig.InputType.Name(data_config.input_type).lower() + + if 'rtp' in input_type: + self._is_rtp = True + else: + self._is_rtp = False + if selected_cols: + self._selected_cols = [int(x) for x in selected_cols.split(',')] + else: + self._selected_cols = None + + def _get_reserved_cols(self, reserved_cols): + if reserved_cols == 'ALL_COLUMNS': + if self._is_rtp: + idx = 0 + reserved_cols = [] + for x in range(len(self._record_defaults) - 1): + if not self._selected_cols or x in self._selected_cols[:-1]: + reserved_cols.append(self._input_fields[idx]) + idx += 1 + else: + reserved_cols.append('no_used_%d' % x) + reserved_cols.append(SINGLE_PLACEHOLDER_FEATURE_KEY) + else: + reserved_cols = self._input_fields + else: + reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != ''] + return reserved_cols + + def _parse_line(self, *fields): + fields = list(fields) + field_dict = {self._input_fields[i]: fields[i] for i in range(len(fields))} + return field_dict + + def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num, + slice_id): + _hive_read = HiveUtils( + data_config=self._data_config, + hive_config=self._hive_config, + selected_cols=','.join(self._input_fields), + record_defaults=self._record_defaults, + input_path=input_path, + mode=tf.estimator.ModeKeys.PREDICT, + task_index=slice_id, + task_num=slice_num)._hive_read + + _input_field_types = [x.input_type for x in self._data_config.input_fields] + + list_type = [get_tf_type(x) for x in _input_field_types] + list_type = tuple(list_type) + list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))] + list_shapes = tuple(list_shapes) + + dataset = tf.data.Dataset.from_generator( + _hive_read, output_types=list_type, output_shapes=list_shapes) + + return dataset + + def _get_writer(self, output_path, slice_id): + if not gfile.Exists(output_path): + gfile.MakeDirs(output_path) + res_path = os.path.join(output_path, 'slice_%d.csv' % slice_id) + table_writer = gfile.GFile(res_path, 'w') + table_writer.write( + self._output_sep.join(self._output_cols + self._reserved_cols) + '\n') + return table_writer + + def _write_line(self, table_writer, outputs): + outputs = '\n'.join( + [self._output_sep.join([str(i) for i in output]) for output in outputs]) + table_writer.write(outputs + '\n') + + def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs): + reserve_vals = [outputs[x] for x in output_cols] + \ + [all_vals[k] 
for k in reserved_cols] + return reserve_vals + + @property + def out_of_range_exception(self): + return (tf.errors.OutOfRangeError) diff --git a/easy_rec/python/input/batch_tfrecord_input.py b/easy_rec/python/input/batch_tfrecord_input.py index dacfba6e0..b4c0e4d00 100644 --- a/easy_rec/python/input/batch_tfrecord_input.py +++ b/easy_rec/python/input/batch_tfrecord_input.py @@ -5,6 +5,7 @@ import tensorflow as tf from easy_rec.python.input.input import Input +from easy_rec.python.utils.tf_utils import get_tf_type if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -25,8 +26,9 @@ def __init__(self, task_index=0, task_num=1, check_mode=False): - super(BatchTFRecordInput, self).__init__(data_config, feature_config, input_path, - task_index, task_num, check_mode) + super(BatchTFRecordInput, + self).__init__(data_config, feature_config, input_path, task_index, + task_num, check_mode) assert data_config.HasField( 'n_data_batch_tfrecord'), 'Need to set n_data_batch_tfrecord in config.' self._input_shapes = [x.input_shape for x in data_config.input_fields] @@ -34,7 +36,7 @@ def __init__(self, for x, t, d, s in zip(self._input_fields, self._input_field_types, self._input_field_defaults, self._input_shapes): d = self.get_type_defaults(t, d) - t = self.get_tf_type(t) + t = get_tf_type(t) self.feature_desc[x] = tf.io.FixedLenSequenceFeature( dtype=t, shape=s, allow_missing=True) diff --git a/easy_rec/python/input/csv_input.py b/easy_rec/python/input/csv_input.py index 43278db78..60ede4f90 100644 --- a/easy_rec/python/input/csv_input.py +++ b/easy_rec/python/input/csv_input.py @@ -46,10 +46,14 @@ def _parse_csv(self, line): else: record_defaults.append('') - check_list = [tf.py_func(check_split, - [line, self._data_config.separator, len(record_defaults), self._check_mode], - Tout=tf.bool) - ] if self._check_mode else [] + check_list = [ + tf.py_func( + check_split, [ + line, self._data_config.separator, + len(record_defaults), self._check_mode + ], + Tout=tf.bool) + ] if self._check_mode else [] with tf.control_dependencies(check_list): fields = tf.decode_csv( line, diff --git a/easy_rec/python/input/csv_input_v2.py b/easy_rec/python/input/csv_input_v2.py index df1a90c86..99b42f225 100644 --- a/easy_rec/python/input/csv_input_v2.py +++ b/easy_rec/python/input/csv_input_v2.py @@ -20,7 +20,8 @@ def __init__(self, def _build(self, mode, params): if type(self._input_path) != list: self._input_path = self._input_path.split(',') - assert len(self._input_path) > 0, 'match no files with %s' % self._input_path + assert len( + self._input_path) > 0, 'match no files with %s' % self._input_path if self._input_path[0].startswith('hdfs://'): # support hdfs input diff --git a/easy_rec/python/input/datahub_input.py b/easy_rec/python/input/datahub_input.py index 059afda27..81401f015 100644 --- a/easy_rec/python/input/datahub_input.py +++ b/easy_rec/python/input/datahub_input.py @@ -8,6 +8,7 @@ from easy_rec.python.input.input import Input from easy_rec.python.utils import odps_util +from easy_rec.python.utils.tf_utils import get_tf_type try: import common_io @@ -105,7 +106,7 @@ def _datahub_generator(self): def _build(self, mode, params): # get input type - list_type = [self.get_tf_type(x) for x in self._input_field_types] + list_type = [get_tf_type(x) for x in self._input_field_types] list_type = tuple(list_type) list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))] list_shapes = tuple(list_shapes) diff --git a/easy_rec/python/input/dummy_input.py b/easy_rec/python/input/dummy_input.py index 
82e28a8e9..49ccc362e 100644 --- a/easy_rec/python/input/dummy_input.py +++ b/easy_rec/python/input/dummy_input.py @@ -4,6 +4,7 @@ import tensorflow as tf from easy_rec.python.input.input import Input +from easy_rec.python.utils.tf_utils import get_tf_type if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -42,7 +43,7 @@ def _build(self, mode, params): for field, field_type, def_val in zip(self._input_fields, self._input_field_types, self._input_field_defaults): - tf_type = self.get_tf_type(field_type) + tf_type = get_tf_type(field_type) def_val = self.get_type_defaults(field_type, default_val=def_val) if field in self._input_vals: diff --git a/easy_rec/python/input/hive_input.py b/easy_rec/python/input/hive_input.py index 694559392..5686433e6 100644 --- a/easy_rec/python/input/hive_input.py +++ b/easy_rec/python/input/hive_input.py @@ -1,93 +1,11 @@ # -*- coding: utf-8 -*- -import logging -import numpy as np import tensorflow as tf from easy_rec.python.input.input import Input from easy_rec.python.utils import odps_util - -try: - from pyhive import hive -except ImportError: - logging.warning('pyhive is not installed.') - - -class TableInfo(object): - - def __init__(self, - tablename, - selected_cols, - partition_kv, - hash_fields, - limit_num, - batch_size=16, - task_index=0, - task_num=1, - epoch=1): - self.tablename = tablename - self.selected_cols = selected_cols - self.partition_kv = partition_kv - self.hash_fields = hash_fields - self.limit_num = limit_num - self.task_index = task_index - self.task_num = task_num - self.batch_size = batch_size - self.epoch = epoch - - def gen_sql(self): - part = '' - if self.partition_kv and len(self.partition_kv) > 0: - res = [] - for k, v in self.partition_kv.items(): - res.append('{}={}'.format(k, v)) - part = ' '.join(res) - sql = """select {} - from {}""".format(self.selected_cols, self.tablename) - assert self.hash_fields is not None, 'hash_fields must not be empty' - fields = [ - 'cast({} as string)'.format(key) for key in self.hash_fields.split(',') - ] - str_fields = ','.join(fields) - if not part: - sql += """ - where hash(concat({}))%{}={} - """.format(str_fields, self.task_num, self.task_index) - else: - sql += """ - where {} and hash(concat({}))%{}={} - """.format(part, str_fields, self.task_num, self.task_index) - if self.limit_num is not None and self.limit_num > 0: - sql += ' limit {}'.format(self.limit_num) - return sql - - -class HiveManager(object): - - def __init__(self, host, port, username, info, database='default'): - self.host = host - self.port = port - self.username = username - self.database = database - self.info = info - - def __call__(self): - conn = hive.Connection( - host=self.host, - port=self.port, - username=self.username, - database=self.database) - cursor = conn.cursor() - sql = self.info.gen_sql() - res = [] - for ep in range(self.info.epoch): - cursor.execute(sql) - for result in cursor.fetchall(): - res.append(result) - if len(res) == self.info.batch_size: - yield res - res = [] - pass +from easy_rec.python.utils.hive_utils import HiveUtils +from easy_rec.python.utils.tf_utils import get_tf_type class HiveInput(Input): @@ -104,6 +22,8 @@ def __init__(self, task_index, task_num, check_mode) if input_path is None: return + self._data_config = data_config + self._feature_config = feature_config self._hive_config = input_path self._eval_batch_size = data_config.eval_batch_size self._fetch_size = self._hive_config.fetch_size @@ -111,28 +31,6 @@ def __init__(self, self._num_epoch = data_config.num_epochs 
self._num_epoch_record = 1 - def _construct_table_info(self, table_name, hash_fields, limit_num): - # sample_table/dt=2014-11-23/name=a - segs = table_name.split('/') - table_name = segs[0].strip() - if len(segs) > 0: - partition_kv = {i.split('=')[0]: i.split('=')[1] for i in segs[1:]} - else: - partition_kv = None - selected_cols = ','.join(self._input_fields) - table_info = TableInfo(table_name, selected_cols, partition_kv, hash_fields, - limit_num, self._data_config.batch_size, - self._task_index, self._task_num, self._num_epoch) - return table_info - - def _construct_hive_connect(self): - conn = hive.Connection( - host=self._hive_config.host, - port=self._hive_config.port, - username=self._hive_config.username, - database=self._hive_config.database) - return conn - def _parse_table(self, *fields): fields = list(fields) inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids} @@ -140,55 +38,6 @@ def _parse_table(self, *fields): inputs[self._input_fields[x]] = fields[x] return inputs - def _hive_read(self): - logging.info('start epoch[%d]' % self._num_epoch_record) - self._num_epoch_record += 1 - table_names = [t for t in str(self._hive_config.table_name).split(',')] - - # check data_config are consistent with odps tables - odps_util.check_input_field_and_types(self._data_config) - - record_defaults = [ - self.get_type_defaults(x, v) - for x, v in zip(self._input_field_types, self._input_field_defaults) - ] - - for table_path in table_names: - table_info = self._construct_table_info(table_path, - self._hive_config.hash_fields, - self._hive_config.limit_num) - batch_size = self.this_batch_size - batch_defaults = [np.array([x] * batch_size) for x in record_defaults] - row_id = 0 - batch_data_np = [x.copy() for x in batch_defaults] - - conn = self._construct_hive_connect() - cursor = conn.cursor() - sql = table_info.gen_sql() - cursor.execute(sql) - - while True: - data = cursor.fetchmany(size=self._fetch_size) - if len(data) == 0: - break - for rows in data: - for col_id in range(len(record_defaults)): - if rows[col_id] not in ['', 'NULL', None]: - batch_data_np[col_id][row_id] = rows[col_id] - else: - batch_data_np[col_id][row_id] = batch_defaults[col_id][row_id] - row_id += 1 - - if row_id >= batch_size: - yield tuple(batch_data_np) - row_id = 0 - - if row_id > 0: - yield tuple([x[:row_id] for x in batch_data_np]) - cursor.close() - conn.close() - logging.info('finish epoch[%d]' % self._num_epoch_record) - def _get_batch_size(self, mode): if mode == tf.estimator.ModeKeys.TRAIN: return self._data_config.batch_size @@ -197,16 +46,29 @@ def _get_batch_size(self, mode): def _build(self, mode, params): # get input type - list_type = [self.get_tf_type(x) for x in self._input_field_types] + list_type = [get_tf_type(x) for x in self._input_field_types] list_type = tuple(list_type) list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))] list_shapes = tuple(list_shapes) - # read odps tables - self.this_batch_size = self._get_batch_size(mode) + # check data_config are consistent with odps tables + odps_util.check_input_field_and_types(self._data_config) + record_defaults = [ + self.get_type_defaults(x, v) + for x, v in zip(self._input_field_types, self._input_field_defaults) + ] + _hive_read = HiveUtils( + data_config=self._data_config, + hive_config=self._hive_config, + selected_cols=','.join(self._input_fields), + record_defaults=record_defaults, + input_path=self._hive_config.table_name, + mode=mode, + task_index=self._task_index, + 
task_num=self._task_num)._hive_read dataset = tf.data.Dataset.from_generator( - self._hive_read, output_types=list_type, output_shapes=list_shapes) + _hive_read, output_types=list_type, output_shapes=list_shapes) if mode == tf.estimator.ModeKeys.TRAIN: dataset = dataset.shuffle( diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py index b91d2aa2e..8f1b93b91 100644 --- a/easy_rec/python/input/input.py +++ b/easy_rec/python/input/input.py @@ -11,10 +11,12 @@ from easy_rec.python.protos.dataset_pb2 import DatasetConfig from easy_rec.python.utils import config_util from easy_rec.python.utils import constant -from easy_rec.python.utils.check_utils import check_split, check_string_to_number +from easy_rec.python.utils.check_utils import check_split +from easy_rec.python.utils.check_utils import check_string_to_number from easy_rec.python.utils.expr_util import get_expression from easy_rec.python.utils.input_utils import get_type_defaults from easy_rec.python.utils.load_class import get_register_class_meta +from easy_rec.python.utils.tf_utils import get_tf_type if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -34,7 +36,7 @@ def __init__(self, check_mode=False): self._data_config = data_config self._check_mode = check_mode - logging.info("check_mode: %s " % self._check_mode) + logging.info('check_mode: %s ' % self._check_mode) # tf.estimator.ModeKeys.*, only available before # calling self._build self._mode = None @@ -139,18 +141,6 @@ def should_stop(self, curr_epoch): total_epoch = 1 return total_epoch is not None and curr_epoch >= total_epoch - def get_tf_type(self, field_type): - type_map = { - DatasetConfig.INT32: tf.int32, - DatasetConfig.INT64: tf.int64, - DatasetConfig.STRING: tf.string, - DatasetConfig.BOOL: tf.bool, - DatasetConfig.FLOAT: tf.float32, - DatasetConfig.DOUBLE: tf.double - } - assert field_type in type_map, 'invalid type: %s' % field_type - return type_map[field_type] - def create_multi_placeholders(self, export_config): """Create multiply placeholders on export, one for each feature. @@ -192,7 +182,7 @@ def create_multi_placeholders(self, export_config): finput = tf.placeholder(tf_type, [None, None], name=placeholder_name) else: ftype = self._input_field_types[fid] - tf_type = self.get_tf_type(ftype) + tf_type = get_tf_type(ftype) logging.info('input_name: %s, dtype: %s' % (input_name, tf_type)) finput = tf.placeholder(tf_type, [None], name=placeholder_name) inputs[input_name] = finput @@ -229,7 +219,7 @@ def create_placeholders(self, export_config): features = {} for tmp_id, fid in enumerate(effective_fids): ftype = self._input_field_types[fid] - tf_type = self.get_tf_type(ftype) + tf_type = get_tf_type(ftype) input_name = self._input_fields[fid] if tf_type in [tf.float32, tf.double, tf.int32, tf.int64]: features[input_name] = tf.string_to_number( @@ -317,8 +307,8 @@ def _preprocess(self, field_dict): elif len(field.get_shape()) == 2: field = tf.squeeze(field, axis=-1) if fc.HasField('kv_separator') and len(fc.input_names) > 1: - assert False, "Tag Feature Error, " \ - "Cannot set kv_separator and multi input_names in one feature config. Feature: %s." % input_0 + assert False, 'Tag Feature Error, ' \ + 'Cannot set kv_separator and multi input_names in one feature config. Feature: %s.' 
% input_0 parsed_dict[input_0] = tf.string_split(field, fc.separator) if fc.HasField('kv_separator'): indices = parsed_dict[input_0].indices @@ -328,11 +318,13 @@ def _preprocess(self, field_dict): tmp_kvs = tf.reshape(tmp_kvs.values, [-1, 2]) tmp_ks, tmp_vs = tmp_kvs[:, 0], tmp_kvs[:, 1] - check_list = [tf.py_func(check_string_to_number, [tmp_vs, input_0], Tout=tf.bool) - ] if self._check_mode else [] + check_list = [ + tf.py_func( + check_string_to_number, [tmp_vs, input_0], Tout=tf.bool) + ] if self._check_mode else [] with tf.control_dependencies(check_list): tmp_vs = tf.string_to_number( - tmp_vs, tf.float32, name='kv_tag_wgt_str_2_flt_%s' % input_0) + tmp_vs, tf.float32, name='kv_tag_wgt_str_2_flt_%s' % input_0) parsed_dict[input_0] = tf.sparse.SparseTensor( indices, tmp_ks, parsed_dict[input_0].dense_shape) input_wgt = input_0 + '_WEIGHT' @@ -340,13 +332,17 @@ def _preprocess(self, field_dict): indices, tmp_vs, parsed_dict[input_0].dense_shape) self._appended_fields.append(input_wgt) if not fc.HasField('hash_bucket_size'): - check_list = [tf.py_func(check_string_to_number, [parsed_dict[input_0].values, input_0], Tout=tf.bool) - ] if self._check_mode else [] + check_list = [ + tf.py_func( + check_string_to_number, + [parsed_dict[input_0].values, input_0], + Tout=tf.bool) + ] if self._check_mode else [] with tf.control_dependencies(check_list): vals = tf.string_to_number( - parsed_dict[input_0].values, - tf.int32, - name='tag_fea_%s' % input_0) + parsed_dict[input_0].values, + tf.int32, + name='tag_fea_%s' % input_0) parsed_dict[input_0] = tf.sparse.SparseTensor( parsed_dict[input_0].indices, vals, parsed_dict[input_0].dense_shape) @@ -356,16 +352,21 @@ def _preprocess(self, field_dict): if len(field.get_shape()) == 0: field = tf.expand_dims(field, axis=0) field = tf.string_split(field, fc.separator) - check_list = [tf.py_func(check_string_to_number, [field.values, input_1], Tout=tf.bool) - ] if self._check_mode else [] + check_list = [ + tf.py_func( + check_string_to_number, [field.values, input_1], + Tout=tf.bool) + ] if self._check_mode else [] with tf.control_dependencies(check_list): field_vals = tf.string_to_number( - field.values, tf.float32, name='tag_wgt_str_2_flt_%s' % input_1) + field.values, + tf.float32, + name='tag_wgt_str_2_flt_%s' % input_1) assert_op = tf.assert_equal( tf.shape(field_vals)[0], tf.shape(parsed_dict[input_0].values)[0], - message="TagFeature Error: The size of %s not equal to the size of %s. Please check input: %s and %s." - % (input_0, input_1, input_0, input_1)) + message='TagFeature Error: The size of %s not equal to the size of %s. Please check input: %s and %s.' 
+ % (input_0, input_1, input_0, input_1)) with tf.control_dependencies([assert_op]): field = tf.sparse.SparseTensor(field.indices, tf.identity(field_vals), @@ -402,27 +403,35 @@ def _preprocess(self, field_dict): parsed_dict[input_0] = tf.sparse.SparseTensor( out_indices, multi_vals.values, out_shape) if (fc.num_buckets > 1 and fc.max_val == fc.min_val): - check_list = [tf.py_func(check_string_to_number, [parsed_dict[input_0].values, input_0], Tout=tf.bool) - ] if self._check_mode else [] + check_list = [ + tf.py_func( + check_string_to_number, + [parsed_dict[input_0].values, input_0], + Tout=tf.bool) + ] if self._check_mode else [] with tf.control_dependencies(check_list): parsed_dict[input_0] = tf.sparse.SparseTensor( - parsed_dict[input_0].indices, - tf.string_to_number( - parsed_dict[input_0].values, - tf.int64, - name='sequence_str_2_int_%s' % input_0), - parsed_dict[input_0].dense_shape) + parsed_dict[input_0].indices, + tf.string_to_number( + parsed_dict[input_0].values, + tf.int64, + name='sequence_str_2_int_%s' % input_0), + parsed_dict[input_0].dense_shape) elif sub_feature_type == fc.RawFeature: - check_list = [tf.py_func(check_string_to_number, [parsed_dict[input_0].values, input_0], Tout=tf.bool) - ] if self._check_mode else [] + check_list = [ + tf.py_func( + check_string_to_number, + [parsed_dict[input_0].values, input_0], + Tout=tf.bool) + ] if self._check_mode else [] with tf.control_dependencies(check_list): parsed_dict[input_0] = tf.sparse.SparseTensor( - parsed_dict[input_0].indices, - tf.string_to_number( - parsed_dict[input_0].values, - tf.float32, - name='sequence_str_2_float_%s' % input_0), - parsed_dict[input_0].dense_shape) + parsed_dict[input_0].indices, + tf.string_to_number( + parsed_dict[input_0].values, + tf.float32, + name='sequence_str_2_float_%s' % input_0), + parsed_dict[input_0].dense_shape) if fc.num_buckets > 1 and fc.max_val > fc.min_val: normalized_values = (parsed_dict[input_0].values - fc.min_val) / ( fc.max_val - fc.min_val) @@ -509,30 +518,40 @@ def _preprocess(self, field_dict): input_0 = fc.input_names[0] if field_dict[input_0].dtype == tf.string: if fc.raw_input_dim > 1: - check_list = [tf.py_func(check_split, - [field_dict[input_0], fc.separator, fc.raw_input_dim, input_0], - Tout=tf.bool) - ] if self._check_mode else [] + check_list = [ + tf.py_func( + check_split, [ + field_dict[input_0], fc.separator, fc.raw_input_dim, + input_0 + ], + Tout=tf.bool) + ] if self._check_mode else [] with tf.control_dependencies(check_list): tmp_fea = tf.string_split(field_dict[input_0], fc.separator) - check_list = [tf.py_func(check_string_to_number, [tmp_fea.values, input_0], Tout=tf.bool) - ] if self._check_mode else [] + check_list = [ + tf.py_func( + check_string_to_number, [tmp_fea.values, input_0], + Tout=tf.bool) + ] if self._check_mode else [] with tf.control_dependencies(check_list): tmp_vals = tf.string_to_number( - tmp_fea.values, - tf.float32, - name='multi_raw_fea_to_flt_%s' % input_0) + tmp_fea.values, + tf.float32, + name='multi_raw_fea_to_flt_%s' % input_0) parsed_dict[input_0] = tf.sparse_to_dense( tmp_fea.indices, [tf.shape(field_dict[input_0])[0], fc.raw_input_dim], tmp_vals, default_value=0) else: - check_list = [tf.py_func(check_string_to_number, [field_dict[input_0], input_0], Tout=tf.bool) - ] if self._check_mode else [] + check_list = [ + tf.py_func( + check_string_to_number, [field_dict[input_0], input_0], + Tout=tf.bool) + ] if self._check_mode else [] with tf.control_dependencies(check_list): - parsed_dict[input_0] = 
tf.string_to_number(field_dict[input_0], - tf.float32) + parsed_dict[input_0] = tf.string_to_number( + field_dict[input_0], tf.float32) elif field_dict[input_0].dtype in [ tf.int32, tf.int64, tf.double, tf.float32 ]: @@ -593,26 +612,39 @@ def _preprocess(self, field_dict): field_dict[input_0], precision=precision) elif fc.num_buckets > 0: if parsed_dict[input_0].dtype == tf.string: - check_list = [tf.py_func(check_string_to_number, [parsed_dict[input_0], input_0], Tout=tf.bool) - ] if self._check_mode else [] + check_list = [ + tf.py_func( + check_string_to_number, [parsed_dict[input_0], input_0], + Tout=tf.bool) + ] if self._check_mode else [] with tf.control_dependencies(check_list): parsed_dict[input_0] = tf.string_to_number( - parsed_dict[input_0], tf.int32, name='%s_str_2_int' % input_0) + parsed_dict[input_0], tf.int32, name='%s_str_2_int' % input_0) elif feature_type == fc.ExprFeature: fea_name = fc.feature_name prefix = 'expr_' for input_name in fc.input_names: - new_input_name = prefix + input_name - if field_dict[input_name].dtype == tf.string: - check_list = [tf.py_func(check_string_to_number, [field_dict[input_name], input_name], Tout=tf.bool) - ] if self._check_mode else [] - with tf.control_dependencies(check_list): - parsed_dict[new_input_name] = tf.string_to_number( - field_dict[input_name], tf.float64, name='%s_str_2_int_for_expr' % new_input_name) - elif field_dict[input_name].dtype in [tf.int32, tf.int64, tf.double, tf.float32]: - parsed_dict[new_input_name] = tf.cast(field_dict[input_name], tf.float64) - else: - assert False, 'invalid input dtype[%s] for expr feature' % str(field_dict[input_name].dtype) + new_input_name = prefix + input_name + if field_dict[input_name].dtype == tf.string: + check_list = [ + tf.py_func( + check_string_to_number, + [field_dict[input_name], input_name], + Tout=tf.bool) + ] if self._check_mode else [] + with tf.control_dependencies(check_list): + parsed_dict[new_input_name] = tf.string_to_number( + field_dict[input_name], + tf.float64, + name='%s_str_2_int_for_expr' % new_input_name) + elif field_dict[input_name].dtype in [ + tf.int32, tf.int64, tf.double, tf.float32 + ]: + parsed_dict[new_input_name] = tf.cast(field_dict[input_name], + tf.float64) + else: + assert False, 'invalid input dtype[%s] for expr feature' % str( + field_dict[input_name].dtype) expression = get_expression( fc.expression, fc.input_names, prefix=prefix) @@ -629,23 +661,29 @@ def _preprocess(self, field_dict): if field_dict[input_name].dtype == tf.string: if self._label_dim[input_id] > 1: logging.info('will split labels[%d]=%s' % (input_id, input_name)) - check_list = [tf.py_func(check_split, - [field_dict[input_name], self._label_sep[input_id], - self._label_dim[input_id], input_name], - Tout=tf.bool) - ] if self._check_mode else [] + check_list = [ + tf.py_func( + check_split, [ + field_dict[input_name], self._label_sep[input_id], + self._label_dim[input_id], input_name + ], + Tout=tf.bool) + ] if self._check_mode else [] with tf.control_dependencies(check_list): parsed_dict[input_name] = tf.string_split( - field_dict[input_name], self._label_sep[input_id]).values - parsed_dict[input_name] = tf.reshape(parsed_dict[input_name], - [-1, self._label_dim[input_id]]) + field_dict[input_name], self._label_sep[input_id]).values + parsed_dict[input_name] = tf.reshape( + parsed_dict[input_name], [-1, self._label_dim[input_id]]) else: parsed_dict[input_name] = field_dict[input_name] - check_list = [tf.py_func(check_string_to_number, [parsed_dict[input_name], input_name], 
Tout=tf.bool) - ] if self._check_mode else [] + check_list = [ + tf.py_func( + check_string_to_number, [parsed_dict[input_name], input_name], + Tout=tf.bool) + ] if self._check_mode else [] with tf.control_dependencies(check_list): parsed_dict[input_name] = tf.string_to_number( - parsed_dict[input_name], tf.float32, name=input_name) + parsed_dict[input_name], tf.float32, name=input_name) else: assert field_dict[input_name].dtype in [ tf.float32, tf.double, tf.int32, tf.int64 diff --git a/easy_rec/python/input/odps_input.py b/easy_rec/python/input/odps_input.py index a6c4c55b4..5bda43763 100644 --- a/easy_rec/python/input/odps_input.py +++ b/easy_rec/python/input/odps_input.py @@ -44,7 +44,8 @@ def _build(self, mode, params): if type(self._input_path) != list: self._input_path = self._input_path.split(',') - assert len(self._input_path) > 0, 'match no files with %s' % self._input_path + assert len( + self._input_path) > 0, 'match no files with %s' % self._input_path if mode == tf.estimator.ModeKeys.TRAIN: if self._data_config.pai_worker_queue: diff --git a/easy_rec/python/input/odps_input_v2.py b/easy_rec/python/input/odps_input_v2.py index dddc7abd4..e806e1c30 100644 --- a/easy_rec/python/input/odps_input_v2.py +++ b/easy_rec/python/input/odps_input_v2.py @@ -35,7 +35,8 @@ def _parse_table(self, *fields): def _build(self, mode, params): if type(self._input_path) != list: self._input_path = self._input_path.split(',') - assert len(self._input_path) > 0, 'match no files with %s' % self._input_path + assert len( + self._input_path) > 0, 'match no files with %s' % self._input_path # check data_config are consistent with odps tables odps_util.check_input_field_and_types(self._data_config) diff --git a/easy_rec/python/input/odps_input_v3.py b/easy_rec/python/input/odps_input_v3.py index 824c10771..f263ca3bd 100644 --- a/easy_rec/python/input/odps_input_v3.py +++ b/easy_rec/python/input/odps_input_v3.py @@ -9,6 +9,7 @@ from easy_rec.python.input.input import Input from easy_rec.python.utils import odps_util +from easy_rec.python.utils.tf_utils import get_tf_type try: import common_io @@ -47,7 +48,8 @@ def _odps_read(self): self._num_epoch += 1 if type(self._input_path) != list: self._input_path = self._input_path.split(',') - assert len(self._input_path) > 0, 'match no files with %s' % self._input_path + assert len( + self._input_path) > 0, 'match no files with %s' % self._input_path # check data_config are consistent with odps tables odps_util.check_input_field_and_types(self._data_config) @@ -90,7 +92,7 @@ def _odps_read(self): def _build(self, mode, params): # get input type - list_type = [self.get_tf_type(x) for x in self._input_field_types] + list_type = [get_tf_type(x) for x in self._input_field_types] list_type = tuple(list_type) list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))] list_shapes = tuple(list_shapes) diff --git a/easy_rec/python/input/rtp_input.py b/easy_rec/python/input/rtp_input.py index 425b6b1a9..4fcfdf2d4 100644 --- a/easy_rec/python/input/rtp_input.py +++ b/easy_rec/python/input/rtp_input.py @@ -5,10 +5,10 @@ import tensorflow as tf from easy_rec.python.input.input import Input -from easy_rec.python.protos.dataset_pb2 import DatasetConfig from easy_rec.python.utils.check_utils import check_split from easy_rec.python.utils.check_utils import check_string_to_number from easy_rec.python.utils.input_utils import string_to_number +from easy_rec.python.utils.tf_utils import get_tf_type if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -78,7 +78,7 @@ def 
_parse_csv(self, line): field = fields[:, x] fname = self._input_fields[idx] ftype = self._input_field_types[idx] - tf_type = self.get_tf_type(ftype) + tf_type = get_tf_type(ftype) if field.dtype in [tf.string]: check_list = [ tf.py_func(check_string_to_number, [field, fname], Tout=tf.bool) diff --git a/easy_rec/python/input/tfrecord_input.py b/easy_rec/python/input/tfrecord_input.py index fda9cea6f..0447b25bb 100644 --- a/easy_rec/python/input/tfrecord_input.py +++ b/easy_rec/python/input/tfrecord_input.py @@ -5,6 +5,7 @@ import tensorflow as tf from easy_rec.python.input.input import Input +from easy_rec.python.utils.tf_utils import get_tf_type if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -26,7 +27,7 @@ def __init__(self, for x, t, d in zip(self._input_fields, self._input_field_types, self._input_field_defaults): d = self.get_type_defaults(t, d) - t = self.get_tf_type(t) + t = get_tf_type(t) self.feature_desc[x] = tf.FixedLenFeature( dtype=t, shape=1, default_value=d) diff --git a/easy_rec/python/main.py b/easy_rec/python/main.py index cd773dcac..0a68b12b5 100644 --- a/easy_rec/python/main.py +++ b/easy_rec/python/main.py @@ -25,6 +25,9 @@ from easy_rec.python.utils import estimator_utils from easy_rec.python.utils import fg_util from easy_rec.python.utils import load_class +from easy_rec.python.utils.config_util import get_eval_input_path +from easy_rec.python.utils.config_util import get_train_input_path +from easy_rec.python.utils.config_util import set_eval_input_path from easy_rec.python.utils.export_big_model import export_big_model from easy_rec.python.utils.export_big_model import export_big_model_to_oss @@ -253,17 +256,6 @@ def train_and_evaluate(pipeline_config_path, continue_train=False): return pipeline_config -def _get_input_object_by_name(pipeline_config, worker_type): - """" get object by worker type. - - pipeline_config: pipeline_config - worker_type: train or eval - """ - input_type = '{}_path'.format(worker_type) - input_name = pipeline_config.WhichOneof(input_type) - return getattr(pipeline_config, input_name) - - def _train_and_evaluate_impl(pipeline_config, continue_train=False, check_mode=False): @@ -278,8 +270,8 @@ def _train_and_evaluate_impl(pipeline_config, % pipeline_config.train_config.train_distribute) pipeline_config.train_config.sync_replicas = False - train_data = _get_input_object_by_name(pipeline_config, 'train') - eval_data = _get_input_object_by_name(pipeline_config, 'eval') + train_data = get_train_input_path(pipeline_config) + eval_data = get_eval_input_path(pipeline_config) distribution = strategy_builder.build(train_config) estimator, run_config = _create_estimator( @@ -298,7 +290,7 @@ def _train_and_evaluate_impl(pipeline_config, train_steps = None if train_config.HasField('num_steps'): train_steps = train_config.num_steps - assert train_steps is not None or data_config.num_epochs > 0, "either num_steps and num_epochs must be set to an integer > 0." + assert train_steps is not None or data_config.num_epochs > 0, 'either num_steps and num_epochs must be set to an integer > 0.' 
if train_steps and data_config.num_epochs: logging.info('Both num_steps and num_epochs are set.') @@ -363,12 +355,10 @@ def evaluate(pipeline_config, fg_util.load_fg_json_to_config(pipeline_config) if eval_data_path is not None: logging.info('Evaluating on data: %s' % eval_data_path) - if isinstance(eval_data_path, list): - pipeline_config.eval_input_path = ','.join(eval_data_path) - else: - pipeline_config.eval_input_path = eval_data_path + set_eval_input_path(pipeline_config, eval_data_path) + train_config = pipeline_config.train_config - eval_data = _get_input_object_by_name(pipeline_config, 'eval') + eval_data = get_eval_input_path(pipeline_config) server_target = None if 'TF_CONFIG' in os.environ: @@ -483,13 +473,9 @@ def distribute_evaluate(pipeline_config, pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config) if eval_data_path is not None: logging.info('Evaluating on data: %s' % eval_data_path) - if isinstance(eval_data_path, list): - pipeline_config.eval_input_path = ','.join(eval_data_path) - else: - pipeline_config.eval_input_path = eval_data_path + set_eval_input_path(pipeline_config, eval_data_path) train_config = pipeline_config.train_config - - eval_data = _get_input_object_by_name(pipeline_config, 'eval') + eval_data = get_eval_input_path(pipeline_config) server_target = None cur_job_name = None @@ -642,12 +628,9 @@ def predict(pipeline_config, checkpoint_path='', data_path=None): fg_util.load_fg_json_to_config(pipeline_config) if data_path is not None: logging.info('Predict on data: %s' % data_path) - pipeline_config.eval_input_path = data_path + set_eval_input_path(pipeline_config, data_path) train_config = pipeline_config.train_config - if pipeline_config.WhichOneof('eval_path') == 'kafka_eval_input': - eval_data = pipeline_config.kafka_eval_input - else: - eval_data = pipeline_config.eval_input_path + eval_data = get_eval_input_path(pipeline_config) distribution = strategy_builder.build(train_config) estimator, _ = _create_estimator(pipeline_config, distribution) diff --git a/easy_rec/python/model/mind.py b/easy_rec/python/model/mind.py index 2efdd2bbf..318f6a886 100644 --- a/easy_rec/python/model/mind.py +++ b/easy_rec/python/model/mind.py @@ -254,7 +254,8 @@ def build_loss_graph(self): loss_dict = super(MIND, self).build_loss_graph() if self._model_config.max_interests_simi < 1.0: loss_dict['reg_interest_simi'] = tf.nn.relu( - self._prediction_dict['interests_simi'] - self._model_config.max_interests_simi) + self._prediction_dict['interests_simi'] - + self._model_config.max_interests_simi) return loss_dict def _build_interest_simi(self): diff --git a/easy_rec/python/predict.py b/easy_rec/python/predict.py index 46633708e..31d9d443d 100644 --- a/easy_rec/python/predict.py +++ b/easy_rec/python/predict.py @@ -8,8 +8,10 @@ import tensorflow as tf from tensorflow.python.lib.io import file_io -from easy_rec.python.inference.predictor import Predictor +from easy_rec.python.inference.predictor import CSVPredictor +from easy_rec.python.inference.predictor import HivePredictor from easy_rec.python.main import predict +from easy_rec.python.utils import config_util if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -18,9 +20,7 @@ format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s', level=logging.INFO) -tf.app.flags.DEFINE_string( - 'input_path', None, 'predict data path, if specified will ' - 'override pipeline_config.eval_input_path') +tf.app.flags.DEFINE_string('input_path', None, 'predict data path') 
tf.app.flags.DEFINE_string('output_path', None, 'path to save predict result') tf.app.flags.DEFINE_integer('batch_size', 1024, help='batch size') @@ -45,7 +45,8 @@ tf.app.flags.DEFINE_string('input_sep', ',', 'separator of predict result file') tf.app.flags.DEFINE_string('output_sep', chr(1), 'separator of predict result file') - +tf.app.flags.DEFINE_string('selected_cols', '', '') +tf.app.flags.DEFINE_string('fg_json', '', '') FLAGS = tf.app.flags.FLAGS @@ -53,7 +54,25 @@ def main(argv): if FLAGS.saved_model_dir: logging.info('Predict by saved_model.') - predictor = Predictor(FLAGS.saved_model_dir) + pipeline_config = config_util.get_configs_from_pipeline_file( + FLAGS.pipeline_config_path, False) + if pipeline_config.WhichOneof('train_path') == 'hive_train_input': + predictor = HivePredictor( + FLAGS.saved_model_dir, + pipeline_config.data_config, + fg_json_path=FLAGS.fg_json_path, + hive_config=pipeline_config.hive_train_input, + selected_cols=FLAGS.selected_cols, + output_sep=FLAGS.output_sep) + else: + predictor = CSVPredictor( + FLAGS.saved_model_dir, + pipeline_config.data_config, + fg_json_path=FLAGS.fg_json_path, + selected_cols=FLAGS.selected_cols, + input_sep=FLAGS.input_sep, + output_sep=FLAGS.output_sep) + logging.info('input_path = %s, output_path = %s' % (FLAGS.input_path, FLAGS.output_path)) if 'TF_CONFIG' in os.environ: @@ -68,10 +87,9 @@ def main(argv): FLAGS.output_path, reserved_cols=FLAGS.reserved_cols, output_cols=FLAGS.output_cols, + batch_size=FLAGS.batch_size, slice_id=task_index, - slice_num=worker_num, - input_sep=FLAGS.input_sep, - output_sep=FLAGS.output_sep) + slice_num=worker_num) else: logging.info('Predict by checkpoint_path.') assert FLAGS.model_dir or FLAGS.pipeline_config_path, 'At least one of model_dir and pipeline_config_path exists.' diff --git a/easy_rec/python/test/hive_input_test.py b/easy_rec/python/test/hive_input_test.py index 9e9fea9d1..c7a8c58db 100644 --- a/easy_rec/python/test/hive_input_test.py +++ b/easy_rec/python/test/hive_input_test.py @@ -1,6 +1,7 @@ # -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
"""Define cv_input, the base class for cv tasks.""" +import logging import os import unittest @@ -9,11 +10,11 @@ from easy_rec.python.input.hive_input import HiveInput from easy_rec.python.protos.dataset_pb2 import DatasetConfig +from easy_rec.python.protos.feature_config_pb2 import FeatureConfig from easy_rec.python.protos.hive_config_pb2 import HiveConfig +from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig +from easy_rec.python.utils import config_util from easy_rec.python.utils import test_utils -from easy_rec.python.utils.config_util import * -from easy_rec.python.utils.test_utils import * -from easy_rec.python.utils.test_utils import _load_config_for_test if tf.__version__ >= '2.0': #tf = tf.compat.v1 @@ -261,11 +262,11 @@ def test_hive_input(self): hive_username、hive_table_name、hive_hash_fields is available.""") def test_mmoe(self): pipeline_config_path = 'samples/emr_script/mmoe/mmoe_census_income.config' - gpus = get_available_gpus() + gpus = test_utils.get_available_gpus() if len(gpus) > 0: - set_gpu_id(gpus[0]) + test_utils.set_gpu_id(gpus[0]) else: - set_gpu_id(None) + test_utils.set_gpu_id(None) if not isinstance(pipeline_config_path, EasyRecConfig): logging.info('testing pipeline config %s' % pipeline_config_path) @@ -275,8 +276,8 @@ def test_mmoe(self): if isinstance(pipeline_config_path, EasyRecConfig): pipeline_config = pipeline_config_path else: - pipeline_config = _load_config_for_test(pipeline_config_path, - self._test_dir) + pipeline_config = test_utils._load_config_for_test( + pipeline_config_path, self._test_dir) pipeline_config.train_config.train_distribute = 0 pipeline_config.train_config.num_gpus_per_worker = 1 @@ -287,7 +288,8 @@ def test_mmoe(self): hyperparam_str = '' train_cmd = 'python -m easy_rec.python.train_eval --pipeline_config_path %s %s' % ( test_pipeline_config_path, hyperparam_str) - proc = run_cmd(train_cmd, '%s/log_%s.txt' % (self._test_dir, 'master')) + proc = test_utils.run_cmd(train_cmd, + '%s/log_%s.txt' % (self._test_dir, 'master')) proc.wait() if proc.returncode != 0: logging.error('train %s failed' % test_pipeline_config_path) diff --git a/easy_rec/python/test/pre_check_test.py b/easy_rec/python/test/pre_check_test.py index dcadd33af..58b295157 100644 --- a/easy_rec/python/test/pre_check_test.py +++ b/easy_rec/python/test/pre_check_test.py @@ -1,19 +1,10 @@ # -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
-import glob import logging -import os -import sys -import unittest -from distutils.version import LooseVersion -import numpy as np import tensorflow as tf -from easy_rec.python.main import predict -from easy_rec.python.utils import config_util -from easy_rec.python.utils import estimator_utils from easy_rec.python.utils import test_utils if tf.__version__ >= '2.0': @@ -24,9 +15,9 @@ class CheckTest(tf.test.TestCase): def setUp(self): - logging.info('Testing %s.%s' % (type(self).__name__, self._testMethodName)) self._test_dir = test_utils.get_tmp_dir() self._success = True + logging.info('Testing %s.%s' % (type(self).__name__, self._testMethodName)) logging.info('test dir: %s' % self._test_dir) def tearDown(self): diff --git a/easy_rec/python/test/predictor_test.py b/easy_rec/python/test/predictor_test.py index bbd496d28..6c6889cb3 100644 --- a/easy_rec/python/test/predictor_test.py +++ b/easy_rec/python/test/predictor_test.py @@ -9,7 +9,9 @@ import numpy as np import tensorflow as tf +from easy_rec.python.inference.predictor import CSVPredictor from easy_rec.python.inference.predictor import Predictor +from easy_rec.python.utils import config_util from easy_rec.python.utils import test_utils from easy_rec.python.utils.test_utils import RunAsSubprocess @@ -123,36 +125,138 @@ def test_fm_pred_dict(self): class PredictorTestOnDS(tf.test.TestCase): def setUp(self): - self._test_input_path = 'data/test/inference/taobao_infer_data.txt' - self._test_dir = test_utils.get_tmp_dir() - self._test_output_path = os.path.join(self._test_dir, 'taobao_infer_result') - self.gpus = test_utils.get_available_gpus() - self.assertTrue(len(self.gpus) > 0, 'no available gpu on this machine') - logging.info('available gpus %s' % self.gpus) - test_utils.set_gpu_id(self.gpus[0]) + self._test_dir = test_utils.get_tmp_dir() + self._test_output_path = None logging.info('Testing %s.%s' % (type(self).__name__, self._testMethodName)) def tearDown(self): - if (os.path.exists(self._test_output_path)): + if self._test_output_path and (os.path.exists(self._test_output_path)): shutil.rmtree(self._test_output_path) test_utils.set_gpu_id(None) + @RunAsSubprocess def test_local_pred(self): - predictor = Predictor('data/test/inference/tb_multitower_export/') + test_input_path = 'data/test/inference/taobao_infer_data.txt' + self._test_output_path = os.path.join(self._test_dir, 'taobao_infer_result') + save_model_dir = 'data/test/inference/tb_multitower_export/' + pipeline_config_path = os.path.join(save_model_dir, + 'assets/pipeline.config') + pipeline_config = config_util.get_configs_from_pipeline_file( + pipeline_config_path, False) + predictor = CSVPredictor( + save_model_dir, + pipeline_config.data_config, + input_sep=',', + output_sep=';', + selected_cols='') + predictor.predict_impl( - self._test_input_path, + test_input_path, self._test_output_path, reserved_cols='ALL_COLUMNS', output_cols='ALL_COLUMNS', slice_id=0, - slice_num=1, + slice_num=1) + header_truth = 'logits;probs;clk;buy;pid;adgroup_id;cate_id;campaign_id;customer;'\ + 'brand;user_id;cms_segid;cms_group_id;final_gender_code;age_level;pvalue_level;' \ + 'shopping_level;occupation;new_user_class_level;tag_category_list;tag_brand_list;price' + + with open(self._test_output_path + '/slice_0.csv', 'r') as f: + output_res = f.readlines() + self.assertTrue(len(output_res) == 101) + self.assertEqual(output_res[0].strip(), header_truth) + + @RunAsSubprocess + def test_local_pred_with_part_col(self): + test_input_path = 'data/test/inference/taobao_infer_data.txt' + 
self._test_output_path = os.path.join(self._test_dir, 'taobao_infer_result') + save_model_dir = 'data/test/inference/tb_multitower_export/' + pipeline_config_path = os.path.join(save_model_dir, + 'assets/pipeline.config') + pipeline_config = config_util.get_configs_from_pipeline_file( + pipeline_config_path, False) + + predictor = CSVPredictor( + save_model_dir, + pipeline_config.data_config, input_sep=',', - output_sep=';') + output_sep=';', + selected_cols='') + predictor.predict_impl( + test_input_path, + self._test_output_path, + reserved_cols='clk,buy,user_id,adgroup_id', + output_cols='probs', + slice_id=0, + slice_num=1) + header_truth = 'probs;clk;buy;user_id;adgroup_id' + + with open(self._test_output_path + '/slice_0.csv', 'r') as f: + output_res = f.readlines() + self.assertTrue(len(output_res) == 101) + self.assertEqual(output_res[0].strip(), header_truth) + + @RunAsSubprocess + def test_local_pred_rtp(self): + test_input_path = 'data/test/inference/taobao_infer_rtp_data.txt' + self._test_output_path = os.path.join(self._test_dir, + 'taobao_test_feature_result') + save_model_dir = 'data/test/inference/tb_multitower_rtp_export/' + pipeline_config_path = os.path.join(save_model_dir, + 'assets/pipeline.config') + pipeline_config = config_util.get_configs_from_pipeline_file( + pipeline_config_path, False) + + predictor = CSVPredictor( + save_model_dir, + pipeline_config.data_config, + input_sep=';', + output_sep=';', + selected_cols='0,3') + predictor.predict_impl( + test_input_path, + self._test_output_path, + reserved_cols='ALL_COLUMNS', + output_cols='ALL_COLUMNS', + slice_id=0, + slice_num=1) + header_truth = 'logits;probs;clk;no_used_1;no_used_2;features' + with open(self._test_output_path + '/slice_0.csv', 'r') as f: + output_res = f.readlines() + self.assertTrue(len(output_res) == 101) + self.assertEqual(output_res[0].strip(), header_truth) + + @RunAsSubprocess + def test_local_pred_rtp_with_part_col(self): + test_input_path = 'data/test/inference/taobao_infer_rtp_data.txt' + self._test_output_path = os.path.join(self._test_dir, + 'taobao_test_feature_result') + save_model_dir = 'data/test/inference/tb_multitower_rtp_export/' + pipeline_config_path = os.path.join(save_model_dir, + 'assets/pipeline.config') + pipeline_config = config_util.get_configs_from_pipeline_file( + pipeline_config_path, False) + + predictor = CSVPredictor( + save_model_dir, + pipeline_config.data_config, + input_sep=';', + output_sep=';', + selected_cols='0,3') + predictor.predict_impl( + test_input_path, + self._test_output_path, + reserved_cols='clk,features,no_used_1', + output_cols='ALL_COLUMNS', + slice_id=0, + slice_num=1) + header_truth = 'logits;probs;clk;features;no_used_1' with open(self._test_output_path + '/slice_0.csv', 'r') as f: output_res = f.readlines() self.assertTrue(len(output_res) == 101) + self.assertEqual(output_res[0].strip(), header_truth) class PredictorTestV2(tf.test.TestCase): diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py index 60797898b..ecf3c76be 100644 --- a/easy_rec/python/test/train_eval_test.py +++ b/easy_rec/python/test/train_eval_test.py @@ -25,6 +25,7 @@ if tf.__version__ >= '2.0': tf = tf.compat.v1 + class TrainEvalTest(tf.test.TestCase): def setUp(self): diff --git a/easy_rec/python/train_eval.py b/easy_rec/python/train_eval.py index f523a41c2..2595d8433 100644 --- a/easy_rec/python/train_eval.py +++ b/easy_rec/python/train_eval.py @@ -12,6 +12,8 @@ from easy_rec.python.utils import estimator_utils from 
easy_rec.python.utils import fg_util from easy_rec.python.utils import hpo_util +from easy_rec.python.utils.config_util import set_eval_input_path +from easy_rec.python.utils.config_util import set_train_input_path from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num_on_ds # NOQA @@ -64,13 +66,9 @@ def main(argv): pipeline_config.model_dir = FLAGS.model_dir logging.info('update model_dir to %s' % pipeline_config.model_dir) if FLAGS.train_input_path: - pipeline_config.train_input_path = ','.join(FLAGS.train_input_path) - logging.info('update train_input_path to %s' % - pipeline_config.train_input_path) + set_train_input_path(pipeline_config, FLAGS.train_input_path) if FLAGS.eval_input_path: - pipeline_config.eval_input_path = ','.join(FLAGS.eval_input_path) - logging.info('update eval_input_path to %s' % - pipeline_config.eval_input_path) + set_eval_input_path(pipeline_config, FLAGS.eval_input_path) if FLAGS.fine_tune_checkpoint: if file_io.file_exists(FLAGS.fine_tune_checkpoint): pipeline_config.train_config.fine_tune_checkpoint = FLAGS.fine_tune_checkpoint diff --git a/easy_rec/python/utils/check_utils.py b/easy_rec/python/utils/check_utils.py index b635a1915..cc9315654 100644 --- a/easy_rec/python/utils/check_utils.py +++ b/easy_rec/python/utils/check_utils.py @@ -12,20 +12,17 @@ def check_split(line, sep, requried_field_num, field_name=''): assert sep, 'must have separator.' + (' field: %s.' % field_name) if field_name else '' - # if isinstance(sep, bytes): - # sep = bytes.decode(sep) - # elif type(sep) != type(str): - # sep = str(sep).encode('utf-8') + for one_line in line: field_num = len(one_line.split(sep)) if field_name: assert_info = 'sep[%s] maybe invalid. field_num=%d, required_num=%d, field: %s, value: %s, ' \ - 'please check separator and data.' % \ - (sep, field_num, requried_field_num, field_name, one_line) + 'please check separator and data.' % \ + (sep, field_num, requried_field_num, field_name, one_line) else: assert_info = 'sep[%s] maybe invalid. field_num=%d, required_num=%d, current line is: %s, ' \ - 'please check separator and data.' % \ - (sep, field_num, requried_field_num, one_line) + 'please check separator and data.' % \ + (sep, field_num, requried_field_num, one_line) assert field_num == requried_field_num, assert_info return True @@ -36,7 +33,7 @@ def check_string_to_number(field_vals, field_name): float(val) except: assert False, 'StringToNumber ERROR: cannot convert string_to_number, field: %s, value: %s. ' \ - 'please check data.' % (field_name, val) + 'please check data.' % (field_name, val) return True @@ -50,14 +47,14 @@ def check_sequence(pipeline_config_path, features): return for seq_att_map in seq_att_maps: assert len(seq_att_map.key) == len(seq_att_map.hist_seq), \ - 'The size of hist_seq must equal to the size of key in one seq_att_map.' + 'The size of hist_seq must equal to the size of key in one seq_att_map.' size_list = [] for hist_seq in seq_att_map.hist_seq: cur_seq_size = len(features[hist_seq].values) size_list.append(cur_seq_size) hist_seqs = ' '.join(seq_att_map.hist_seq) assert len(set(size_list)) == 1, \ - 'SequenceFeature Error: The size in [%s] should be consistent. Please check input: [%s].' % \ + 'SequenceFeature Error: The size in [%s] should be consistent. Please check input: [%s].' % \ (hist_seqs, hist_seqs) @@ -75,7 +72,7 @@ def check_env_and_input_path(pipeline_config, input_path): if input_type in ignore_input_list: return True assert_info = 'Current InputType is %s, InputPath is %s. 
Please check InputType and InputPath.' % \ - (input_type_name, input_path) + (input_type_name, input_path) if input_type_name.startswith('Odps'): # is on pai for path in input_path.split(','): diff --git a/easy_rec/python/utils/config_util.py b/easy_rec/python/utils/config_util.py index 38ce860d4..65a0df56f 100644 --- a/easy_rec/python/utils/config_util.py +++ b/easy_rec/python/utils/config_util.py @@ -354,3 +354,88 @@ def get_compatible_feature_configs(pipeline_config): else: feature_configs = pipeline_config.feature_config.features return feature_configs + + +def get_input_name_from_fg_json(fg_json): + if not fg_json: + return [] + input_names = [] + for fea in fg_json['features']: + if 'feature_name' in fea: + input_names.append(fea['feature_name']) + elif 'sequence_name' in fea: + sequence_name = fea['sequence_name'] + for seq_fea in fea['features']: + assert 'feature_name' in seq_fea + feature_name = seq_fea['feature_name'] + input_names.append(sequence_name + '__' + feature_name) + return input_names + + +def get_train_input_path(pipeline_config): + input_name = pipeline_config.WhichOneof('train_path') + return getattr(pipeline_config, input_name) + + +def get_eval_input_path(pipeline_config): + input_name = pipeline_config.WhichOneof('eval_path') + return getattr(pipeline_config, input_name) + + +def set_train_input_path(pipeline_config, train_input_path): + if pipeline_config.WhichOneof('train_path') == 'hive_train_input': + if isinstance(train_input_path, list): + assert len( + train_input_path + ) <= 1, 'only support one hive_train_input.table_name when hive input' + pipeline_config.hive_train_input.table_name = train_input_path[0] + else: + assert len( + train_input_path.split(',') + ) <= 1, 'only support one hive_train_input.table_name when hive input' + pipeline_config.hive_train_input.table_name = train_input_path + logging.info('update hive_train_input.table_name to %s' % + pipeline_config.hive_train_input.table_name) + + elif pipeline_config.WhichOneof('train_path') == 'kafka_train_input': + if isinstance(train_input_path, list): + pipeline_config.kafka_train_input = ','.join(train_input_path) + else: + pipeline_config.kafka_train_input = train_input_path + else: + if isinstance(train_input_path, list): + pipeline_config.train_input_path = ','.join(train_input_path) + else: + pipeline_config.train_input_path = train_input_path + logging.info('update train_input_path to %s' % + pipeline_config.train_input_path) + return pipeline_config + + +def set_eval_input_path(pipeline_config, eval_input_path): + if pipeline_config.WhichOneof('eval_path') == 'hive_eval_input': + if isinstance(eval_input_path, list): + assert len( + eval_input_path + ) <= 1, 'only support one hive_eval_input.table_name when hive input' + pipeline_config.hive_eval_input.table_name = eval_input_path[0] + else: + assert len( + eval_input_path.split(',') + ) <= 1, 'only support one hive_eval_input.table_name when hive input' + pipeline_config.hive_eval_input.table_name = eval_input_path + logging.info('update hive_train_input.table_name to %s' % + pipeline_config.hive_eval_input.table_name) + elif pipeline_config.WhichOneof('train_path') == 'kafka_eval_input': + if isinstance(eval_input_path, list): + pipeline_config.kafka_eval_input = ','.join(eval_input_path) + else: + pipeline_config.kafka_eval_input = eval_input_path + else: + if isinstance(eval_input_path, list): + pipeline_config.eval_input_path = ','.join(eval_input_path) + else: + pipeline_config.eval_input_path = eval_input_path + 
logging.info('update train_input_path to %s' % + pipeline_config.eval_input_path) + return pipeline_config diff --git a/easy_rec/python/utils/hive_utils.py b/easy_rec/python/utils/hive_utils.py new file mode 100644 index 000000000..44cbc2d41 --- /dev/null +++ b/easy_rec/python/utils/hive_utils.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +import logging + +import numpy as np +import tensorflow as tf + +try: + from pyhive import hive +except ImportError: + logging.warning('pyhive is not installed.') + + +class TableInfo(object): + + def __init__(self, + tablename, + selected_cols, + partition_kv, + hash_fields, + limit_num, + batch_size=16, + task_index=0, + task_num=1, + epoch=1): + self.tablename = tablename + self.selected_cols = selected_cols + self.partition_kv = partition_kv + self.hash_fields = hash_fields + self.limit_num = limit_num + self.task_index = task_index + self.task_num = task_num + self.batch_size = batch_size + self.epoch = epoch + + def gen_sql(self): + part = '' + if self.partition_kv and len(self.partition_kv) > 0: + res = [] + for k, v in self.partition_kv.items(): + res.append('{}={}'.format(k, v)) + part = ' '.join(res) + sql = """select {} + from {}""".format(self.selected_cols, self.tablename) + assert self.hash_fields is not None, 'hash_fields must not be empty' + fields = [ + 'cast({} as string)'.format(key) for key in self.hash_fields.split(',') + ] + str_fields = ','.join(fields) + if not part: + sql += """ + where hash(concat({}))%{}={} + """.format(str_fields, self.task_num, self.task_index) + else: + sql += """ + where {} and hash(concat({}))%{}={} + """.format(part, str_fields, self.task_num, self.task_index) + if self.limit_num is not None and self.limit_num > 0: + sql += ' limit {}'.format(self.limit_num) + return sql + + +class HiveUtils(object): + """Common IO based interface, could run at local or on data science.""" + + def __init__(self, + data_config, + hive_config, + mode, + selected_cols, + record_defaults, + input_path=None, + task_index=0, + task_num=1): + + self._data_config = data_config + self._hive_config = hive_config + self._eval_batch_size = data_config.eval_batch_size + self._fetch_size = self._hive_config.fetch_size + self._this_batch_size = self._get_batch_size(mode) + + self._num_epoch = data_config.num_epochs + self._num_epoch_record = 1 + self._task_index = task_index + self._task_num = task_num + self._input_path = input_path.split(',') + self._selected_cols = selected_cols + self._record_defaults = record_defaults + + def _construct_table_info(self, table_name, hash_fields, limit_num): + # sample_table/dt=2014-11-23/name=a + segs = table_name.split('/') + table_name = segs[0].strip() + if len(segs) > 0: + partition_kv = {i.split('=')[0]: i.split('=')[1] for i in segs[1:]} + else: + partition_kv = None + + table_info = TableInfo(table_name, self._selected_cols, partition_kv, + hash_fields, limit_num, self._data_config.batch_size, + self._task_index, self._task_num, self._num_epoch) + return table_info + + def _construct_hive_connect(self): + conn = hive.Connection( + host=self._hive_config.host, + port=self._hive_config.port, + username=self._hive_config.username, + database=self._hive_config.database) + return conn + + def _get_batch_size(self, mode): + if mode == tf.estimator.ModeKeys.TRAIN: + return self._data_config.batch_size + else: + return self._eval_batch_size + + def _hive_read(self): + logging.info('start epoch[%d]' % self._num_epoch_record) + self._num_epoch_record += 1 + + for table_path in self._input_path: + 
table_info = self._construct_table_info(table_path, + self._hive_config.hash_fields, + self._hive_config.limit_num) + batch_size = self._this_batch_size + batch_defaults = [] + for x in self._record_defaults: + if isinstance(x, str): + batch_defaults.append(np.array([x] * batch_size, dtype='S500')) + else: + batch_defaults.append(np.array([x] * batch_size)) + + row_id = 0 + batch_data_np = [x.copy() for x in batch_defaults] + + conn = self._construct_hive_connect() + cursor = conn.cursor() + sql = table_info.gen_sql() + cursor.execute(sql) + + while True: + data = cursor.fetchmany(size=self._fetch_size) + if len(data) == 0: + break + for rows in data: + for col_id in range(len(self._record_defaults)): + if rows[col_id] not in ['', 'NULL', None]: + batch_data_np[col_id][row_id] = rows[col_id] + else: + batch_data_np[col_id][row_id] = batch_defaults[col_id][row_id] + row_id += 1 + + if row_id >= batch_size: + yield tuple(batch_data_np) + row_id = 0 + + if row_id > 0: + yield tuple([x[:row_id] for x in batch_data_np]) + cursor.close() + conn.close() + logging.info('finish epoch[%d]' % self._num_epoch_record) diff --git a/easy_rec/python/utils/tf_utils.py b/easy_rec/python/utils/tf_utils.py new file mode 100644 index 000000000..a17cceb91 --- /dev/null +++ b/easy_rec/python/utils/tf_utils.py @@ -0,0 +1,22 @@ +# -*- encoding:utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +"""Common functions used for odps input.""" +import tensorflow as tf + +from easy_rec.python.protos.dataset_pb2 import DatasetConfig + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + + +def get_tf_type(field_type): + type_map = { + DatasetConfig.INT32: tf.int32, + DatasetConfig.INT64: tf.int64, + DatasetConfig.STRING: tf.string, + DatasetConfig.BOOL: tf.bool, + DatasetConfig.FLOAT: tf.float32, + DatasetConfig.DOUBLE: tf.double + } + assert field_type in type_map, 'invalid type: %s' % field_type + return type_map[field_type] diff --git a/easy_rec/version.py b/easy_rec/version.py index d9ccf7a9f..46b9af067 100644 --- a/easy_rec/version.py +++ b/easy_rec/version.py @@ -1,3 +1,3 @@ # -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
-__version__ = '0.4.6' +__version__ = '0.4.7' diff --git a/pai_jobs/run.py b/pai_jobs/run.py index 0d3d53347..6a37456bc 100644 --- a/pai_jobs/run.py +++ b/pai_jobs/run.py @@ -10,7 +10,7 @@ import tensorflow as tf import easy_rec -from easy_rec.python.inference.predictor import Predictor +from easy_rec.python.inference.predictor import ODPSPredictor from easy_rec.python.inference.vector_retrieve import VectorRetrieve from easy_rec.python.tools.pre_check import run_check from easy_rec.python.utils import config_util @@ -166,6 +166,8 @@ 'hyperparameter save metric path') tf.app.flags.DEFINE_string('asset_files', None, 'extra files to add to export') tf.app.flags.DEFINE_bool('check_mode', False, 'is use check mode') +tf.app.flags.DEFINE_string('fg_json_path', None, '') + FLAGS = tf.app.flags.FLAGS @@ -409,7 +411,12 @@ def main(argv): profiling_file = FLAGS.profiling_file if FLAGS.task_index == 0 else None if profiling_file is not None: print('profiling_file = %s ' % profiling_file) - predictor = Predictor(FLAGS.saved_model_dir, profiling_file=profiling_file) + predictor = ODPSPredictor( + FLAGS.saved_model_dir, + fg_json_path=FLAGS.fg_json_path, + profiling_file=profiling_file, + all_cols=FLAGS.all_cols, + all_col_types=FLAGS.all_col_types) input_table, output_table = FLAGS.tables, FLAGS.outputs logging.info('input_table = %s, output_table = %s' % (input_table, output_table)) @@ -417,9 +424,6 @@ def main(argv): predictor.predict_impl( input_table, output_table, - all_cols=FLAGS.all_cols, - all_col_types=FLAGS.all_col_types, - selected_cols=FLAGS.selected_cols, reserved_cols=FLAGS.reserved_cols, output_cols=FLAGS.output_cols, batch_size=FLAGS.batch_size,