Skip to content

Commit

Permalink
【PIR OpTest Fix No.6】 fix test_tdm_sampler_op (#59617)
Browse files Browse the repository at this point in the history
* fix test_tdm_sampler_op

* test infer_meta

* Update op_test.py

* Update op_test.py

* Update op_test.py

* Update op_test.py

* Update ternary.cc

* set dtype

* fix
  • Loading branch information
xingmingyyj authored Jan 16, 2024
1 parent 446f7a7 commit 5e8349f
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 0 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
'shadow_feed',
'shuffle_batch',
'sparse_momentum',
'tdm_sampler',
'soft_relu',
'uniform_random_batch_size_like',
'c_reduce_min',
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,16 @@
inplace : (mean -> mean_out), (variance -> variance_out)
optional : reserve_space

- op : tdm_sampler
args: (Tensor x, Tensor travel, Tensor layer, bool output_positive=true, int[] neg_samples_num_list={}, int[] layer_offset_lod={}, int seed = 0, int dtype=2)
output: Tensor(out), Tensor(labels), Tensor(mask)
infer_meta:
func : TdmSamplerInferMeta
kernel:
func : tdm_sampler
data_type : x
optional : labels

- op : tile
args : (Tensor x, IntArray repeat_times = {})
output : Tensor(out)
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ const std::unordered_set<std::string> LegacyOpList = {
ShareDataOp::name(),
SparseMomentumOp::name(),
GetTensorFromSelectedRowsOp::name(),
TdmSamplerOp::name(),
RowConvOp::name(),
RowConvGradOp::name(),
SoftReluOp::name(),
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3158,6 +3158,12 @@
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]

- op : tdm_sampler
inputs:
{x : X, travel : Travel, layer : Layer}
outputs:
{out : Out, labels : Labels, mask : Mask}

- op : thresholded_relu
inputs :
x : X
Expand Down
41 changes: 41 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1525,5 +1525,46 @@ void QuantLinearInferMeta(const MetaTensor& x,
y->share_lod(x);
y->set_dtype(x.dtype());
}
void TdmSamplerInferMeta(const MetaTensor& x,
const MetaTensor& travel,
const MetaTensor& layer,
bool output_positive,
const std::vector<int>& neg_samples_num_list,
const std::vector<int>& layer_offset_lod,
int seed,
int dtype,
MetaTensor* out,
MetaTensor* labels,
MetaTensor* mask,
MetaConfig config) {
auto neg_samples_num_vec = neg_samples_num_list;
auto output_positive_flag = output_positive;

int64_t sample_res_length = 0;
for (auto sample_nums : neg_samples_num_vec) {
sample_res_length += sample_nums + (int64_t)output_positive_flag;
}
auto ddim = phi::make_ddim({-1, sample_res_length});
auto input_dims = x.dims();
if (config.is_runtime) {
auto output_dims = phi::vectorize(input_dims);
auto batch_size = output_dims[0];
out->set_dims(phi::make_ddim({batch_size, sample_res_length}));
mask->set_dims(phi::make_ddim({batch_size, sample_res_length}));
if (labels) {
labels->set_dims(phi::make_ddim({batch_size, sample_res_length}));
}
} else {
out->set_dims(ddim);
mask->set_dims(ddim);
if (labels) {
labels->set_dims(ddim);
}
}
out->set_dtype(x.dtype());
mask->set_dtype(x.dtype());
if (labels) {
labels->set_dtype(x.dtype());
}
}
} // namespace phi
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,17 @@ void QuantLinearInferMeta(const MetaTensor& x,
float quant_min_bound,
MetaTensor* y);

void TdmSamplerInferMeta(const MetaTensor& x,
const MetaTensor& travel,
const MetaTensor& layer,
bool output_positive,
const std::vector<int>& neg_samples_num_list,
const std::vector<int>& layer_offset_lod,
int seed,
int dtype,
MetaTensor* out,
MetaTensor* labels,
MetaTensor* mask,
MetaConfig config = MetaConfig());

} // namespace phi
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_no_check_list
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ test_poisson_op
test_randint_op
test_randperm_op
test_seed_op
test_tdm_sampler_op
test_uniform_random_bf16_op
test_uniform_random_inplace_op
test_shuffle_batch_op
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ test_squeeze2_op
test_sum_mkldnn_op
test_svd_op
test_take_along_axis_op
test_tdm_sampler_op
test_temporal_shift_op
test_tile_op
test_top_k_v2_op
Expand Down

0 comments on commit 5e8349f

Please sign in to comment.