diff --git a/paddle/fluid/operators/match_matrix_tensor_op.cc b/paddle/fluid/operators/match_matrix_tensor_op.cc
index 7055f3ca95efe..dd2dcd548180a 100644
--- a/paddle/fluid/operators/match_matrix_tensor_op.cc
+++ b/paddle/fluid/operators/match_matrix_tensor_op.cc
@@ -251,6 +251,93 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
     auto* tmp = ctx.Output<phi::DenseTensor>("Tmp");
     int dim_t = ctx.Attr<int>("dim_t");
+
+    const auto& x_lod = x->lod();
+    PADDLE_ENFORCE_EQ(x_lod.empty(),
+                      false,
+                      platform::errors::InvalidArgument(
+                          "The Input(X) should hold LoD information, but "
+                          "received Input(X).lod() is empty."));
+    const auto& x_lod_0 = x_lod[0];
+    PADDLE_ENFORCE_GE(x_lod_0.size(),
+                      2,
+                      platform::errors::InvalidArgument(
+                          "The dimensions of Input(X)'s LoD data should be "
+                          "equal to 2, but received %d.",
+                          x_lod_0.size()));
+    auto x_dims = x->dims();
+    PADDLE_ENFORCE_EQ(x_dims[0],
+                      static_cast<int64_t>(x_lod_0.back()),
+                      platform::errors::InvalidArgument(
+                          "The last element of Input(X)'s LoD data should be "
+                          "equal to the first dimension of Input(X). "
+                          "But received the last element of Input(X)'s LoD "
+                          "data is %d, the first dimension of Input(X) is %d.",
+                          x_lod_0.back(),
+                          x_dims[0]));
+    const auto& y_lod = y->lod();
+    PADDLE_ENFORCE_EQ(y_lod.empty(),
+                      false,
+                      platform::errors::InvalidArgument(
+                          "The Input(Y) should hold LoD information, but "
+                          "received Input(Y).lod() is empty."));
+    const auto& y_lod_0 = y_lod[0];
+    PADDLE_ENFORCE_GE(y_lod_0.size(),
+                      2,
+                      platform::errors::InvalidArgument(
+                          "The dimensions of Input(Y)'s LoD data should be "
+                          "equal to 2, but received %d.",
+                          y_lod_0.size()));
+    auto y_dims = y->dims();
+    PADDLE_ENFORCE_EQ(y_dims[0],
+                      static_cast<int64_t>(y_lod_0.back()),
+                      platform::errors::InvalidArgument(
+                          "The last element of Input(Y)'s LoD data should be "
+                          "equal to the first dimension of Input(Y). "
+                          "But received the last element of Input(Y)'s LoD "
+                          "data is %d, the first dimension of Input(Y) is %d.",
+                          y_lod_0.back(),
+                          y_dims[0]));
+
+    PADDLE_ENFORCE_EQ(x_lod_0.size(),
+                      y_lod_0.size(),
+                      platform::errors::InvalidArgument(
+                          "The dimensions of Input(X)'s and Input(Y)'s LoD "
+                          "data should be equal. "
+                          "But received the dimensions of Input(X)'s LoD is "
+                          "%d, the dimensions of Input(Y)'s LoD is %d.",
+                          x_lod_0.size(),
+                          y_lod_0.size()));
+
+    int64_t out_dim_0 = 0;
+    int64_t tmp_dim_0 = -1;
+    for (size_t i = 1; i < x_lod_0.size(); i++) {
+      int64_t x_len = x_lod_0[i] - x_lod_0[i - 1];
+      int64_t y_len = y_lod_0[i] - y_lod_0[i - 1];
+      out_dim_0 += (x_len * y_len);
+    }
+    out_dim_0 *= dim_t;
+
+    tmp_dim_0 = x_dims[0] * dim_t * x_dims[1];
+    std::vector<int64_t> out_dims_vec{out_dim_0};
+    out_dims_vec.push_back(1);
+    std::vector<int64_t> tmp_dims_vec{tmp_dim_0};
+    tmp_dims_vec.push_back(1);
+
+    auto& out_meta = out->meta();
+    phi::DenseTensorMeta new_out_meta(out_meta.dtype,
+                                      common::make_ddim(out_dims_vec),
+                                      out_meta.layout,
+                                      out_meta.lod);
+    out->set_meta(new_out_meta);
+
+    auto& tmp_meta = tmp->meta();
+    phi::DenseTensorMeta new_tmp_meta(tmp_meta.dtype,
+                                      common::make_ddim(tmp_dims_vec),
+                                      tmp_meta.layout,
+                                      tmp_meta.lod);
+    tmp->set_meta(new_tmp_meta);
+
     int64_t dim_in = x->dims()[1];
     const auto& offset_l = x->lod()[0];
diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
index 314f4b343d481..74e9d57c9d97a 100644
--- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -141,6 +141,7 @@
     'sparse_momentum',
     'soft_relu',
     'uniform_random_batch_size_like',
+    'match_matrix_tensor',
     'c_reduce_min',
     'c_reduce_min_',
 ]
diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
index 97fa1a6879e0a..51efa8a076a96 100644
--- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -1491,6 +1491,15 @@
     data_type: param
   optional: master_param, master_param_out

+- op: match_matrix_tensor
+  args: (Tensor x, Tensor y, Tensor w, int dim_t=1)
+  output: Tensor(out), Tensor(tmp)
+  infer_meta:
+    func: MatchMatrixTensorInferMeta
+  kernel:
+    func: match_matrix_tensor
+  backward: match_matrix_tensor_grad
+
 - op: nce
   args: (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false)
   output: Tensor(cost), Tensor(sample_logits), Tensor(sample_labels)
diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
index 71124cf559396..0f98f707c1247 100644
--- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -938,6 +938,15 @@
     func: fused_elemwise_add_activation_grad
   optional : x, intermediate_out

+- backward_op: match_matrix_tensor_grad
+  forward: match_matrix_tensor (Tensor x, Tensor y, Tensor w, int dim_t=1) -> Tensor(out), Tensor(tmp)
+  args: (Tensor x, Tensor y, Tensor w, Tensor tmp, Tensor out_grad, int dim_t=1)
+  output: Tensor(x_grad), Tensor(y_grad), Tensor(w_grad)
+  infer_meta:
+    func: MatchMatrixTensorGradInferMeta
+  kernel:
+    func: match_matrix_tensor_grad
+
 - backward_op: shuffle_batch_grad
   forward: shuffle_batch (Tensor x, Tensor seed, int startup_seed=0) -> Tensor(out), Tensor(shuffle_idx), Tensor(seed_out)
   args: (Tensor shuffle_idx, Tensor out_grad,int startup_seed=0)
diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc
index d0e800f157c5f..b8da2b6cb8ba1 100644
--- a/paddle/fluid/pir/dialect/operator/utils/utils.cc
+++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -60,6 +60,8 @@ const std::unordered_set<std::string> LegacyOpList = {
     RowConvGradOp::name(),
     SoftReluOp::name(),
     SoftReluGradOp::name(),
+    MatchMatrixTensorOp::name(),
+    MatchMatrixTensorGradOp::name(),
     NceOp::name(),
     NceGradOp::name(),
     CReduceMinOp::name()};
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 7071df37e4aa5..31bde1a6b7858 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -3511,6 +3511,13 @@
   attrs:
     pivot : pivots

+- op: match_matrix_tensor
+  backward: match_matrix_tensor_grad
+  inputs:
+    {x : X, y : Y, w : W}
+  outputs:
+    {out : Out, tmp : Tmp}
+
 - op: memcpy
   inputs:
     x: X
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 6d33afeb94898..8d67620fc49b8 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -662,6 +662,31 @@ void MarginCrossEntropyGradInferMeta(const MetaTensor& logits,
   logits_grad->set_dtype(softmax.dtype());
 }

+void MatchMatrixTensorGradInferMeta(const MetaTensor& x,
+                                    const MetaTensor& y,
+                                    const MetaTensor& w,
+                                    const MetaTensor& tmp,
+                                    const MetaTensor& out_grad,
+                                    int dim_t,
+                                    MetaTensor* x_grad,
+                                    MetaTensor* y_grad,
+                                    MetaTensor* w_grad) {
+  if (x_grad != nullptr) {
+    x_grad->set_dims(x.dims());
+    x_grad->share_lod(x);
+    x_grad->set_dtype(x.dtype());
+  }
+  if (y_grad != nullptr) {
+    y_grad->set_dims(y.dims());
+    y_grad->share_lod(y);
+    y_grad->set_dtype(y.dtype());
+  }
+  if (w_grad != nullptr) {
+    w_grad->set_dims(w.dims());
+    w_grad->set_dtype(w.dtype());
+  }
+}
+
 void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
                                    const MetaTensor& mask,
                                    const MetaTensor& dout,
@@ -1352,5 +1377,4 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
     value_grad->share_lod(values);
   }
 }
-
 }  // namespace phi
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 3112fa8b9ddad..f1458f3f4b8fd 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -312,6 +312,16 @@ void MarginCrossEntropyGradInferMeta(const MetaTensor& logits,
                                      float scale,
                                      MetaTensor* logits_grad);

+void MatchMatrixTensorGradInferMeta(const MetaTensor& x,
+                                    const MetaTensor& y,
+                                    const MetaTensor& w,
+                                    const MetaTensor& tmp,
+                                    const MetaTensor& out_grad,
+                                    int dim_t,
+                                    MetaTensor* x_grad,
+                                    MetaTensor* y_grad,
+                                    MetaTensor* w_grad);
+
 void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
                                    const MetaTensor& mask,
                                    const MetaTensor& dout,
diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index 27fe7dc19ae4c..6c18cbf19e8b1 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -725,6 +725,79 @@ void LinspaceInferMeta(const MetaTensor& start,
   LinspaceRawInferMeta(start, stop, number, out);
 }

+void MatchMatrixTensorInferMeta(const MetaTensor& x,
+                                const MetaTensor& y,
+                                const MetaTensor& w,
+                                int dim_t,
+                                MetaTensor* out,
+                                MetaTensor* tmp,
+                                MetaConfig config) {
+  auto x_dims = x.dims();
+  PADDLE_ENFORCE_EQ(x_dims.size(),
+                    2,
+                    phi::errors::InvalidArgument(
+                        "The dimensions of Input(X) should be equal to 2, "
+                        "but received %d.",
+                        x_dims.size()));
+
+  auto y_dims = y.dims();
+  PADDLE_ENFORCE_EQ(y_dims.size(),
+                    2,
+                    phi::errors::InvalidArgument(
+                        "The dimensions of Input(Y) should be equal to 2, "
+                        "but received %d.",
+                        y_dims.size()));
+
+  auto w_dims = w.dims();
+  PADDLE_ENFORCE_EQ(w_dims.size(),
+                    3,
+                    phi::errors::InvalidArgument(
+                        "The dimensions of Input(W) should be equal to 3, "
+                        "but received %d.",
+                        w_dims.size()));
+
+  PADDLE_ENFORCE_EQ(
+      w_dims[0],
+      x_dims[1],
+      phi::errors::InvalidArgument(
+          "The first dimension of Input(W) should be equal to the second "
+          "dimension of Input(X). But received the first dimension of Input(W) "
+          "is %d, the second dimension of Input(X) is %d.",
+          w_dims[0],
+          x_dims[1]));
+  PADDLE_ENFORCE_EQ(
+      w_dims[1],
+      dim_t,
+      phi::errors::InvalidArgument(
+          "The second dimension of Input(W) should be equal to 'dim_t', but "
+          "received the second dimension of Input(W) is %d, 'dim_t' is %d.",
+          w_dims[1],
+          dim_t));
+  PADDLE_ENFORCE_EQ(
+      w_dims[2],
+      y_dims[1],
+      phi::errors::InvalidArgument(
+          "The last dimension of Input(W) should be equal to "
+          "the second dimension of Input(Y). But received the last dimension "
+          "of Input(W) is %d, the second dimension of Input(Y) is %d.",
+          w_dims[2],
+          y_dims[1]));
+
+  int64_t out_dim_0 = -1;
+  int64_t tmp_dim_0 = -1;
+  if (!config.is_runtime) {
+    out->share_lod(x);
+    std::vector<int64_t> out_dims_vec{out_dim_0};
+    out_dims_vec.push_back(1);
+    std::vector<int64_t> tmp_dims_vec{tmp_dim_0};
+    tmp_dims_vec.push_back(1);
+    out->set_dims(common::make_ddim(out_dims_vec));
+    out->set_dtype(x.dtype());
+    tmp->set_dims(common::make_ddim(tmp_dims_vec));
+    tmp->set_dtype(x.dtype());
+  }
+}
+
 void MultiClassNMSInferMeta(const MetaTensor& bboxes,
                             const MetaTensor& scores,
                             const MetaTensor& rois_num,
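To make the shape contract above concrete: MatchMatrixTensorInferMeta only publishes placeholder (-1, 1) shapes in its compile-time branch (!config.is_runtime), and the CPU kernel patched earlier recomputes the real first dimensions from the LoD. The Python sketch below mirrors that runtime logic; it is illustrative only, and the helper name compute_out_dims is hypothetical rather than part of this patch.

    # Hypothetical mirror of the runtime shape logic in CPUMatchMatrixTensorOPKernel.
    # x_lod_0 / y_lod_0 are level-0 LoD offset vectors; [0, 2, 5] means two
    # sequences of lengths 2 and 3.
    def compute_out_dims(x_lod_0, y_lod_0, x_dims, dim_t):
        out_dim_0 = 0
        for i in range(1, len(x_lod_0)):
            x_len = x_lod_0[i] - x_lod_0[i - 1]  # i-th sequence length in X
            y_len = y_lod_0[i] - y_lod_0[i - 1]  # i-th sequence length in Y
            out_dim_0 += x_len * y_len           # one score per (x step, y step) pair
        out_dim_0 *= dim_t                       # dim_t match channels per pair
        tmp_dim_0 = x_dims[0] * dim_t * x_dims[1]  # flattened buffer for X * W
        return [out_dim_0, 1], [tmp_dim_0, 1]

    # At compile time the infer-meta publishes [-1, 1] for both outputs; at run
    # time, e.g. X lengths (2, 3), Y lengths (3, 2), dim_t = 2:
    assert compute_out_dims([0, 2, 5], [0, 3, 5], (5, 3), 2) == ([24, 1], [30, 1])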
diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h
index 7272941504ff2..fee8718b04d7e 100644
--- a/paddle/phi/infermeta/ternary.h
+++ b/paddle/phi/infermeta/ternary.h
@@ -132,6 +132,14 @@ void LinspaceInferMeta(const MetaTensor& start,
                        DataType dtype,
                        MetaTensor* out);

+void MatchMatrixTensorInferMeta(const MetaTensor& x,
+                                const MetaTensor& y,
+                                const MetaTensor& w,
+                                int dim_t,
+                                MetaTensor* out,
+                                MetaTensor* tmp,
+                                MetaConfig config = MetaConfig());
+
 void MultiClassNMSInferMeta(const MetaTensor& bboxes,
                             const MetaTensor& scores,
                             const MetaTensor& rois_num,
diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list
index 7e652790bdece..12deff35f1ea4 100644
--- a/test/white_list/pir_op_test_white_list
+++ b/test/white_list/pir_op_test_white_list
@@ -165,6 +165,7 @@ test_lu_op
 test_lu_unpack_op
 test_margin_cross_entropy_op
 test_masked_select_op
+test_match_matrix_tensor_op
 test_matmul_v2_op
 test_matmul_v2_op_static_build
 test_matrix_nms_op
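For reviewers tracing the whitelisted test_match_matrix_tensor_op, the numpy sketch below states the semantics the op is expected to satisfy for a single sequence pair. The function name is mine, and the assumption that the kernel flattens the result in (dim_t, len_x, len_y) block order should be checked against the test file rather than taken from this patch.

    import numpy as np

    def match_matrix_reference(x, y, w):
        """Bilinear match scores for one sequence pair:
        score[t, i, j] = x[i] @ w[:, t, :] @ y[j].

        x: (len_x, h), y: (len_y, h), w: (h, dim_t, h). The intermediate
        x @ w product is what the op exposes as 'Tmp'; the real kernel
        flattens both results into (N, 1) column tensors per sequence.
        """
        tmp = np.einsum('ih,htk->tik', x, w)     # (dim_t, len_x, h), i.e. X * W
        return np.einsum('tik,jk->tij', tmp, y)  # (dim_t, len_x, len_y)

    # Shapes line up with the infer-meta checks above: w_dims[0] == x_dims[1],
    # w_dims[1] == dim_t, w_dims[2] == y_dims[1].
    x = np.random.rand(5, 3)
    y = np.random.rand(4, 3)
    w = np.random.rand(3, 2, 3)
    assert match_matrix_reference(x, y, w).shape == (2, 5, 4)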