diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 79cbad13c0f56..5da56c2093116 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -127,6 +127,7 @@ 'fused_scale_bias_add_relu', 'fused_dconv_drelu_dbn', 'fused_dot_product_attention', + 'nce', 'lars_momentum', 'recv_v2', 'rnn_', diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index b992c139b8543..8b1f830b63152 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1471,6 +1471,17 @@ data_type: param optional: master_param, master_param_out +- op: nce + args: (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false) + output: Tensor(cost), Tensor(sample_logits), Tensor(sample_labels) + infer_meta: + func: NceInferMeta + kernel: + func: nce + data_type: input + optional: bias, sample_weight, custom_dist_probs, custom_dist_alias, custom_dist_alias_probs + backward: nce_grad + - op: number_count args: (Tensor numbers, int upper_range) output: Tensor(out) diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml index bf0b939267e1b..afa544626369f 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml @@ -503,6 +503,18 @@ func : multiply_triple_grad optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad +- backward_op : nce_grad + forward: nec (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false) -> Tensor(cost), Tensor(sample_logits), Tensor(sample_labels) + args: (Tensor input, Tensor label, Tensor bias, Tensor weight, Tensor sample_logits, Tensor sample_labels, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, Tensor cost_grad, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false) + output: Tensor(input_grad), Tensor(bias_grad), Tensor(weight_grad) + infer_meta: + func: NceGradInferMeta + param: [input, bias, weight] + kernel: + func: nce_grad + data_type: input + optional: bias, sample_weight, custom_dist_probs, custom_dist_alias, custom_dist_alias_probs + - backward_op : norm_grad forward : norm (Tensor x, int axis, float epsilon, bool is_test) -> Tensor(out), Tensor(norm) args : (Tensor x, Tensor norm, Tensor out_grad, int axis, float epsilon, bool is_test) diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index ebc1615a16d51..68b8edbd6da8a 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -59,6 +59,8 @@ const std::unordered_set LegacyOpList = { RowConvGradOp::name(), SoftReluOp::name(), SoftReluGradOp::name(), + NceOp::name(), + NceGradOp::name(), CReduceMinOp::name()}; const std::unordered_set OneDNNLegacyOpList = {}; diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index e605dab154337..be902f4317a0e 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -3510,6 +3510,13 @@ outputs : out : Out +- op: nce + backward: nce_grad + inputs: + {input : Input, label : Label, weight : Weight, bias : Bias, sample_weight : SampleWeight, custom_dist_probs : CustomDistProbs, custom_dist_alias : CustomDistAlias, custom_dist_alias_probs : CustomDistAliasProbs} + outputs: + {cost : Cost, sample_logits : SampleLogits, sample_labels : SampleLabels} + - op: number_count inputs : {numbers: numbers} diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index ee2388762668b..b310bef397e06 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -818,6 +818,33 @@ void NanmedianGradInferMeta(const MetaTensor& x, x_grad->set_dtype(x.dtype()); } +void NceGradInferMeta(const MetaTensor& input, + const MetaTensor& bias, + const MetaTensor& weight, + MetaTensor* input_grad, + MetaTensor* bias_grad, + MetaTensor* weight_grad + +) { + auto x_dims = input.dims(); + if (input_grad != nullptr) { + input_grad->set_dims(x_dims); + input_grad->set_dtype(input.dtype()); + } + + auto w_dims = weight.dims(); + if (weight_grad) { + weight_grad->set_dims(w_dims); + weight_grad->set_dtype(weight.dtype()); + } + + auto bias_dims = bias.dims(); + if (bias_grad) { + bias_grad->set_dims(bias_dims); + bias_grad->set_dtype(bias.dtype()); + } +} + void NllLossGradInferMeta(const MetaTensor& x, const MetaTensor& label, const MetaTensor& weight, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 922bafed0add8..b37e1b3d7b84b 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -361,6 +361,13 @@ void NanmedianGradInferMeta(const MetaTensor& x, bool keep_dim, MetaTensor* x_grad); +void NceGradInferMeta(const MetaTensor& input, + const MetaTensor& bias, + const MetaTensor& weight, + MetaTensor* input_grad, + MetaTensor* bias_grad, + MetaTensor* weight_grad); + void NllLossGradInferMeta(const MetaTensor& input, const MetaTensor& label, const MetaTensor& weight, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 6250b3a3b23c8..f8d82a749b3d8 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3181,6 +3181,98 @@ void MultiplexInferMeta(const std::vector& ins, out->set_dtype(ins[0]->dtype()); } +void NceInferMeta(const MetaTensor& input, + const MetaTensor& label, + const MetaTensor& weight, + const MetaTensor& bias, + const MetaTensor& sample_weight, + const MetaTensor& custom_dist_probs, + const MetaTensor& custom_dist_alias, + const MetaTensor& custom_dist_alias_probs, + int num_total_classes, + const std::vector& custom_neg_classes, + int num_neg_samples, + int sampler, + int seed, + bool is_sparse, + bool remote_prefetch, + bool is_test, + MetaTensor* cost, + MetaTensor* sample_logits, + MetaTensor* sample_labels, + MetaConfig config) { + auto x_dims = input.dims(); + auto label_dims = label.dims(); + if (config.is_runtime || (x_dims[0] > 0 && label_dims[0] > 0)) { + PADDLE_ENFORCE_EQ( + x_dims[0], + label_dims[0], + phi::errors::InvalidArgument( + "The first dimension of Input(Input) and Input(Label) should be " + "equal in runtime. But received: Input(Input)'s shape = [%s] " + "with 1st dim = %d, Input(Label)'s shape = [%s] with 1st dim = " + "%d.", + x_dims, + x_dims[0], + label_dims, + label_dims[0])); + } + int num_true_classes = + static_cast(label_dims.size() == 2 ? label_dims[1] : 1); + if (bias) { + PADDLE_ENFORCE_EQ( + weight.dims()[0], + bias.dims()[0], + phi::errors::InvalidArgument( + "The first dimension of Input(Weight) and Input(Bias) " + "should be equal. But received: Input(Weight)'s shape = [%s] " + "with 1st dim = %d, and Input(Bias)'s shape = [%s] with 1st dim " + "= %d.", + weight.dims(), + weight.dims()[0], + bias.dims(), + bias.dims()[0])); + } + + PADDLE_ENFORCE_EQ( + num_total_classes, + weight.dims()[0], + phi::errors::InvalidArgument( + "The number of total classes should be equal to the first " + "dimension of Input(Weight). But received: Attr(num_total_classes) " + "= %d, Input(Weight)'s shape = [%s] with 1st dim = %d.", + num_total_classes, + weight.dims(), + weight.dims()[0])); + if (custom_neg_classes.size() > 0) { + PADDLE_ENFORCE_EQ( + custom_neg_classes.size(), + static_cast(num_neg_samples), + phi::errors::InvalidArgument( + "The size of Attr(custom_neg_classes) should be equal " + "to the number of negative samples. But received: " + "custom_neg_classes.size() = %d, num_neg_samples = %d.", + custom_neg_classes.size(), + num_neg_samples)); + } + // set dims of output(Out) + std::vector out_dims; + out_dims.push_back(x_dims[0]); + out_dims.push_back(1); + cost->set_dims(common::make_ddim(out_dims)); + cost->set_dtype(DataType::FLOAT32); + + if (!is_test) { + // set dims of output(SampleOut) + std::vector sample_out_dims; + sample_out_dims.push_back(x_dims[0]); + sample_out_dims.push_back( + (num_true_classes == -1) ? -1 : (num_neg_samples + num_true_classes)); + sample_logits->set_dims(common::make_ddim(sample_out_dims)); + sample_labels->set_dims(common::make_ddim(sample_out_dims)); + } +} + void PsroiPoolInferMeta(const MetaTensor& x, const MetaTensor& rois, const MetaTensor& rois_num, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index f51c3dacb1909..ebda904e4d4a7 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -569,6 +569,27 @@ void MultiplexInferMeta(const std::vector& ins, const MetaTensor& ids, MetaTensor* out); +void NceInferMeta(const MetaTensor& input, + const MetaTensor& label, + const MetaTensor& weight, + const MetaTensor& bias, + const MetaTensor& sample_weight, + const MetaTensor& custom_dist_probs, + const MetaTensor& custom_dist_alias, + const MetaTensor& custom_dist_alias_probs, + int num_total_classes, + const std::vector& custom_neg_classes, + int num_neg_samples, + int sampler, + int seed, + bool is_sparse, + bool remote_prefetch, + bool is_test, + MetaTensor* cost, + MetaTensor* sample_logits, + MetaTensor* sample_labels, + MetaConfig config = MetaConfig()); + void PsroiPoolInferMeta(const MetaTensor& x, const MetaTensor& rois, const MetaTensor& rois_num, diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index 045e8b4df9459..503426b971ae0 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -210,6 +210,7 @@ test_multinomial_op test_multiplex_op test_mv_op test_nanmedian +test_nce test_nearest_interp_mkldnn_op test_nearest_interp_v2_op test_nextafter_op