From 5e8349f949429e3c8e602263d21e508f95c2c095 Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:40:43 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PIR=20OpTest=20Fix=20No.6=E3=80=91=20f?= =?UTF-8?q?ix=20test=5Ftdm=5Fsampler=5Fop=20(#59617)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix test_tdm_sampler_op * test infer_meta * Update op_test.py * Update op_test.py * Update op_test.py * Update op_test.py * Update ternary.cc * set dtype * fix --- .../pir/dialect/op_generator/ops_api_gen.py | 1 + paddle/fluid/pir/dialect/operator/ir/ops.yaml | 10 +++++ .../fluid/pir/dialect/operator/utils/utils.cc | 1 + paddle/phi/api/yaml/op_compat.yaml | 6 +++ paddle/phi/infermeta/ternary.cc | 41 +++++++++++++++++++ paddle/phi/infermeta/ternary.h | 13 ++++++ test/white_list/pir_op_test_no_check_list | 1 + test/white_list/pir_op_test_white_list | 1 + 8 files changed, 74 insertions(+) diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 38fc56de5a0bc..d2a91bdc4eb6f 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -143,6 +143,7 @@ 'shadow_feed', 'shuffle_batch', 'sparse_momentum', + 'tdm_sampler', 'soft_relu', 'uniform_random_batch_size_like', 'c_reduce_min', diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 7ead31ee570db..fd62878066b38 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1310,6 +1310,16 @@ inplace : (mean -> mean_out), (variance -> variance_out) optional : reserve_space +- op : tdm_sampler + args: (Tensor x, Tensor travel, Tensor layer, bool output_positive=true, int[] neg_samples_num_list={}, int[] layer_offset_lod={}, int seed = 0, int dtype=2) + output: Tensor(out), Tensor(labels), Tensor(mask) + infer_meta: + func : TdmSamplerInferMeta + kernel: + func : tdm_sampler + data_type : x + optional : labels + - op : tile args : (Tensor x, IntArray repeat_times = {}) output : Tensor(out) diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 14e79508e5c47..2b72adb0c50fd 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -62,6 +62,7 @@ const std::unordered_set LegacyOpList = { ShareDataOp::name(), SparseMomentumOp::name(), GetTensorFromSelectedRowsOp::name(), + TdmSamplerOp::name(), RowConvOp::name(), RowConvGradOp::name(), SoftReluOp::name(), diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 2b207c8a94c63..1c952b2b23327 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -3158,6 +3158,12 @@ extra : attrs : [bool use_mkldnn = false, bool use_cudnn = false] +- op : tdm_sampler + inputs: + {x : X, travel : Travel, layer : Layer} + outputs: + {out : Out, labels : Labels, mask : Mask} + - op : thresholded_relu inputs : x : X diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 27fe7dc19ae4c..2f53eb9ef7199 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1525,5 +1525,46 @@ void QuantLinearInferMeta(const MetaTensor& x, y->share_lod(x); y->set_dtype(x.dtype()); } +void TdmSamplerInferMeta(const MetaTensor& x, + const MetaTensor& travel, + const MetaTensor& layer, + bool output_positive, + const std::vector& neg_samples_num_list, + const std::vector& layer_offset_lod, + int seed, + int dtype, + MetaTensor* out, + MetaTensor* labels, + MetaTensor* mask, + MetaConfig config) { + auto neg_samples_num_vec = neg_samples_num_list; + auto output_positive_flag = output_positive; + int64_t sample_res_length = 0; + for (auto sample_nums : neg_samples_num_vec) { + sample_res_length += sample_nums + (int64_t)output_positive_flag; + } + auto ddim = phi::make_ddim({-1, sample_res_length}); + auto input_dims = x.dims(); + if (config.is_runtime) { + auto output_dims = phi::vectorize(input_dims); + auto batch_size = output_dims[0]; + out->set_dims(phi::make_ddim({batch_size, sample_res_length})); + mask->set_dims(phi::make_ddim({batch_size, sample_res_length})); + if (labels) { + labels->set_dims(phi::make_ddim({batch_size, sample_res_length})); + } + } else { + out->set_dims(ddim); + mask->set_dims(ddim); + if (labels) { + labels->set_dims(ddim); + } + } + out->set_dtype(x.dtype()); + mask->set_dtype(x.dtype()); + if (labels) { + labels->set_dtype(x.dtype()); + } +} } // namespace phi diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 7272941504ff2..36ed324ee1d6c 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -238,4 +238,17 @@ void QuantLinearInferMeta(const MetaTensor& x, float quant_min_bound, MetaTensor* y); +void TdmSamplerInferMeta(const MetaTensor& x, + const MetaTensor& travel, + const MetaTensor& layer, + bool output_positive, + const std::vector& neg_samples_num_list, + const std::vector& layer_offset_lod, + int seed, + int dtype, + MetaTensor* out, + MetaTensor* labels, + MetaTensor* mask, + MetaConfig config = MetaConfig()); + } // namespace phi diff --git a/test/white_list/pir_op_test_no_check_list b/test/white_list/pir_op_test_no_check_list index 00d8f054df131..6799c396a517d 100644 --- a/test/white_list/pir_op_test_no_check_list +++ b/test/white_list/pir_op_test_no_check_list @@ -7,6 +7,7 @@ test_poisson_op test_randint_op test_randperm_op test_seed_op +test_tdm_sampler_op test_uniform_random_bf16_op test_uniform_random_inplace_op test_shuffle_batch_op diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index 7e652790bdece..fdaa016752a4e 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -254,6 +254,7 @@ test_squeeze2_op test_sum_mkldnn_op test_svd_op test_take_along_axis_op +test_tdm_sampler_op test_temporal_shift_op test_tile_op test_top_k_v2_op